Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 25 additions & 28 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
RealType,
TFloat,
TFloatHighPrecision,
TFloatOrBFloat16,
TInt,
TReal,
TRealOrUInt8,
Expand Down Expand Up @@ -3564,14 +3563,14 @@ def aten_flipud(self: TensorType) -> TensorType:


@torch_op("aten::floor", traceable=True)
def aten_floor(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_floor(self: TFloat) -> TFloat:
"""floor(Tensor self) -> Tensor"""

return op.Floor(self)


@torch_op("math::floor", traceable=True)
def python_math_floor(self: TFloatOrBFloat16) -> TInt:
def python_math_floor(self: TFloat) -> TInt:
"""floor(Tensor self) -> Tensor"""
floor = op.Floor(self)
return op.Cast(floor, to=INT64.dtype)
Expand Down Expand Up @@ -4533,7 +4532,7 @@ def aten_isfinite(self: TFloatHighPrecision) -> BOOL:


@torch_op("aten::isinf")
def aten_isinf(self: TFloatOrBFloat16) -> BOOL:
def aten_isinf(self: TFloat) -> BOOL:
"""isinf(Tensor self) -> Tensor"""

# Added Cast inside the function so it can support all real dtypes naturally
Expand All @@ -4542,14 +4541,14 @@ def aten_isinf(self: TFloatOrBFloat16) -> BOOL:


@torch_op("aten::isnan")
def aten_isnan(self: TFloatOrBFloat16) -> BOOL:
def aten_isnan(self: TFloat) -> BOOL:
"""isnan(Tensor self) -> Tensor"""

return op.IsNaN(self)


@torch_op("aten::isneginf")
def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:
def aten_isneginf(self: TFloat) -> BOOL:
"""isneginf(Tensor self) -> Tensor"""

# Added Cast inside the function so it can support all real dtypes naturally
Expand All @@ -4558,7 +4557,7 @@ def aten_isneginf(self: TFloatOrBFloat16) -> BOOL:


@torch_op("aten::isposinf")
def aten_isposinf(self: TFloatOrBFloat16) -> BOOL:
def aten_isposinf(self: TFloat) -> BOOL:
"""isposinf(Tensor self) -> Tensor"""

# Added Cast inside the function so it can support all real dtypes naturally
Expand Down Expand Up @@ -4778,42 +4777,42 @@ def aten_linspace(


@torch_op("aten::log", traceable=True)
def aten_log(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_log(self: TFloat) -> TFloat:
"""log(Tensor self) -> Tensor"""

return op.Log(self)


@torch_op("aten::log10", traceable=True)
def aten_log10(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_log10(self: TFloat) -> TFloat:
"""log10(Tensor self) -> Tensor"""

return op.Div(op.Log(self), op.CastLike(op.Log(10.0), self))


@torch_op("aten::log1p")
def aten_log1p(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_log1p(self: TFloat) -> TFloat:
"""log1p(Tensor self) -> Tensor"""

return op.Log(op.Add(self, 1.0))


@torch_op("aten::log2", traceable=True)
def aten_log2(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_log2(self: TFloat) -> TFloat:
"""log2(Tensor self) -> Tensor"""

return op.Div(op.Log(self), op.CastLike(op.Log(2.0), self))


@torch_op("aten::logaddexp", traceable=True)
def aten_logaddexp(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_logaddexp(self: TFloat, other: TFloat) -> TFloat:
"""logaddexp(Tensor self, Tensor other) -> Tensor"""

return op.Log(op.Add(op.Exp(self), op.Exp(other)))


@torch_op("aten::logaddexp2", traceable=True)
def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_logaddexp2(self: TFloat, other: TFloat) -> TFloat:
"""logaddexp2(Tensor self, Tensor other) -> Tensor"""
two = op.CastLike(2.0, self)
summation = op.Add(op.Pow(two, self), op.Pow(two, other))
Expand All @@ -4822,7 +4821,7 @@ def aten_logaddexp2(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOr


@torch_op("aten::logcumsumexp", traceable=True)
def aten_logcumsumexp(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
def aten_logcumsumexp(self: TFloat, dim: int) -> TFloat:
"""logcumsumexp(Tensor self, int dim) -> Tensor"""

if IsScalar(self):
Expand Down Expand Up @@ -4908,12 +4907,12 @@ def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:


@torch_op("aten::logit", private=True)
def _aten_logit_onnx(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def _aten_logit_onnx(self: TFloat) -> TFloat:
return op.Log(op.Div(self, op.Sub(1.0, self)))


@torch_op("aten::logit", private=True)
def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat16:
def _aten_logit_clamp_onnx(self: TFloat, eps: float) -> TFloat:
eps = op.CastLike(eps, self)
one = op.CastLike(1.0, self)
temporary_self = op.Where(self <= one - eps, self, one - eps)
Expand All @@ -4923,7 +4922,7 @@ def _aten_logit_clamp_onnx(self: TFloatOrBFloat16, eps: float) -> TFloatOrBFloat


@torch_op("aten::logit", trace_only=True)
def aten_logit(self: TFloatOrBFloat16, eps: Optional[float] = None) -> TFloatOrBFloat16:
def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat:
"""logit(Tensor self, float? eps=None) -> Tensor"""
if eps is None:
return _aten_logit_onnx(self)
Expand Down Expand Up @@ -6041,9 +6040,7 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType:


@torch_op("aten::native_dropout", trace_only=True)
def aten_native_dropout(
input: TFloatOrBFloat16, p: float, train: bool = True
) -> Tuple[TFloatOrBFloat16, BOOL]:
def aten_native_dropout(input: TFloat, p: float, train: bool = True) -> Tuple[TFloat, BOOL]:
"""native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)"""

result, mask = op.Dropout(input, p, train)
Expand Down Expand Up @@ -7055,7 +7052,7 @@ def aten_real(self: TensorType) -> TensorType:


@torch_op("aten::reciprocal")
def aten_reciprocal(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_reciprocal(self: TFloat) -> TFloat:
"""reciprocal(Tensor self) -> Tensor"""

return op.Reciprocal(self)
Expand All @@ -7074,7 +7071,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:


@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"))
def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_remainder(self: TFloat, other: TFloat) -> TFloat:
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""

# TODO(justinchuby): Improve fp16 precision by following the logic in
Expand Down Expand Up @@ -7355,7 +7352,7 @@ def aten_rrelu(


@torch_op("aten::rsqrt", traceable=True)
def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_rsqrt(self: TFloat) -> TFloat:
"""rsqrt(Tensor self) -> Tensor"""

return op.Reciprocal(op.Sqrt(self))
Expand Down Expand Up @@ -7562,7 +7559,7 @@ def aten_sgn(self: TensorType) -> TensorType:


@torch_op("aten::sigmoid", traceable=True)
def aten_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_sigmoid(self: TFloat) -> TFloat:
"""sigmoid(Tensor self) -> Tensor"""

return op.Sigmoid(self)
Expand Down Expand Up @@ -7724,7 +7721,7 @@ def aten_smm(self: TensorType, mat2: TensorType) -> TensorType:


@torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True)
def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrBFloat16:
def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
"""softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = IsScalar(self)
Expand All @@ -7741,7 +7738,7 @@ def aten_softmax(self: TFloatOrBFloat16, dim: int, dtype: int = -1) -> TFloatOrB


@torch_op(("aten::softmax.int", "aten::special_softmax"), traceable=True)
def aten_softmax_no_dtype(self: TFloatOrBFloat16, dim: int) -> TFloatOrBFloat16:
def aten_softmax_no_dtype(self: TFloat, dim: int) -> TFloat:
"""softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = IsScalar(self)
Expand Down Expand Up @@ -7812,7 +7809,7 @@ def aten_split_with_sizes_copy(


@torch_op("aten::sqrt", traceable=True)
def aten_sqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_sqrt(self: TFloat) -> TFloat:
"""sqrt(Tensor self) -> Tensor"""

return op.Sqrt(self)
Expand Down Expand Up @@ -8402,7 +8399,7 @@ def aten_triu_indices(row: int, col: int, offset: int = 0) -> TensorType:


@torch_op("aten::trunc")
def aten_trunc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_trunc(self: TFloat) -> TFloat:
"""trunc(Tensor self) -> Tensor"""

# Reference https://github.com/onnx/onnx/issues/4588#issuecomment-1463970126
Expand Down
11 changes: 5 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from onnxscript.function_libs.torch_lib.tensor_typing import (
IntType,
TFloat,
TFloatOrBFloat16,
TFloatOrUInt8,
TInt,
TReal,
Expand Down Expand Up @@ -364,13 +363,13 @@ def aten_conv_depthwise3d(

@torch_op("aten::cross_entropy_loss", traceable=True)
def aten_cross_entropy_loss(
self: TFloatOrBFloat16,
self: TFloat,
target: IntType,
weight: Optional[TFloatOrBFloat16] = None,
weight: Optional[TFloat] = None,
reduction: int = 1, # default is 'mean'
ignore_index: int = -100,
label_smoothing: float = 0.0, # this was ignored due to ONNX not support
) -> TFloatOrBFloat16:
) -> TFloat:
"""cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor"""

if reduction == 0: # "none"
Expand Down Expand Up @@ -812,7 +811,7 @@ def aten_l1_loss(self: TensorType, target: TensorType, reduction: int = 1) -> Te


@torch_op("aten::leaky_relu")
def aten_leaky_relu(self: TFloatOrBFloat16, negative_slope: float = 0.01) -> TFloatOrBFloat16:
def aten_leaky_relu(self: TFloat, negative_slope: float = 0.01) -> TFloat:
"""leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor"""

return op.LeakyRelu(self, alpha=negative_slope)
Expand Down Expand Up @@ -850,7 +849,7 @@ def aten_linear_bias(input: TFloat, weight: TFloat, bias: TFloat) -> TFloat:


@torch_op("aten::log_sigmoid")
def aten_log_sigmoid(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_log_sigmoid(self: TFloat) -> TFloat:
"""log_sigmoid(Tensor self) -> Tensor"""

return op.Log(op.Sigmoid(self))
Expand Down
16 changes: 7 additions & 9 deletions onnxscript/function_libs/torch_lib/ops/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TFloatOrBFloat16
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType

Expand Down Expand Up @@ -92,21 +92,21 @@ def aten_special_entr(self: TensorType) -> TensorType:


@torch_op(("aten::erf", "aten::special_erf"))
def aten_special_erf(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_special_erf(self: TFloat) -> TFloat:
"""erf(Tensor self) -> Tensor"""

return op.Erf(self)


@torch_op(("aten::erfc", "aten::special_erfc"))
def aten_special_erfc(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_special_erfc(self: TFloat) -> TFloat:
"""erfc(Tensor self) -> Tensor"""

return op.Sub(1, op.Erf(self))


@torch_op("aten::special_erfcx")
def aten_special_erfcx(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_special_erfcx(self: TFloat) -> TFloat:
"""special_erfcx(Tensor self) -> Tensor"""

return op.Mul(op.Exp(op.Pow(self, 2)), op.Sub(1, op.Erf(self)))
Expand All @@ -131,7 +131,7 @@ def aten_special_expit(self: TensorType) -> TensorType:


@torch_op(("aten::expm1", "aten::special_expm1"))
def aten_special_expm1(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_special_expm1(self: TFloat) -> TFloat:
"""special_expm1(Tensor self) -> Tensor"""

return op.Sub(op.Exp(self), 1)
Expand Down Expand Up @@ -216,9 +216,7 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:


@torch_op(("aten::log_softmax.int", "aten::special_log_softmax"), trace_only=True)
def aten_special_log_softmax(
self: TFloatOrBFloat16, dim: int, dtype: int = -1
) -> TFloatOrBFloat16:
def aten_special_log_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat:
"""special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = IsScalar(self)
Expand Down Expand Up @@ -366,7 +364,7 @@ def aten_special_xlog1py(self: TensorType, other: TensorType) -> TensorType:


@torch_op(("aten::xlogy.Tensor", "aten::xlogy.Scalar_Self", "aten::xlogy.Scalar_Other"))
def aten_special_xlogy(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
def aten_special_xlogy(self: TFloat, other: TFloat) -> TFloat:
"""special_xlogy(Tensor self, Tensor other) -> Tensor"""

# https://pytorch.org/docs/stable/special.html#torch.special.xlogy
Expand Down
3 changes: 1 addition & 2 deletions onnxscript/function_libs/torch_lib/tensor_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
INT64,
UINT8,
]
_FloatType = Union[FLOAT16, FLOAT, DOUBLE]
_FloatType = Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]
IntType = Union[INT8, INT16, INT32, INT64]
RealType = Union[
BFLOAT16,
Expand All @@ -61,7 +61,6 @@
TTensor2 = TypeVar("TTensor2", bound=_TensorType)
TTensorOrString = TypeVar("TTensorOrString", bound=Union[_TensorType, STRING])
TFloat = TypeVar("TFloat", bound=_FloatType)
TFloatOrBFloat16 = TypeVar("TFloatOrBFloat16", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16])
TFloatOrUInt8 = TypeVar("TFloatOrUInt8", bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8])
TInt = TypeVar("TInt", bound=IntType)
TReal = TypeVar("TReal", bound=RealType)
Expand Down
Loading