diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 1f21a1afe8d6..382fc6fce37d 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -519,9 +519,6 @@ class ConvertAtenCompareOp : public OpConversionPattern { if (!rhsTy) { rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), rhs.getType()); - // use lhs's element type as compute type - rhs = - hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy.getElementType()); rhsTy = dyn_cast(rhs.getType()); } @@ -538,16 +535,9 @@ class ConvertAtenCompareOp : public OpConversionPattern { if (isa(lhsElemTy) && isa(rhsElemTy)) { lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy); - } else if (isa(lhsElemTy) && - isa(rhsElemTy)) { - rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy); } else { - if (lhsElemTy.getIntOrFloatBitWidth() > - rhsElemTy.getIntOrFloatBitWidth()) { - rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy); - } else { - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy); - } + // use lhs's element type as compute type + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy); } lhsElemTy = dyn_cast(lhs.getType()).getElementType(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 0430ba9d5a47..c7cbe560b3e6 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -536,7 +536,6 @@ "AtenPolarFloatModule_basic", "DiagonalWithStaticShapeModule_basic", "EinsumStaticDiagonalDimensionModule_basic", - "ElementwiseIntTensorLtFloatScalarModule_basic", "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic", "ElementwiseRemainderScalarModule_Float_NegativeDividend_basic", "ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",