14 changes: 2 additions & 12 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -519,9 +519,6 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
   if (!rhsTy) {
     rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
                                        rhs.getType());
-    // use lhs's element type as compute type
-    rhs =
-        hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy.getElementType());
     rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
   }

@@ -538,16 +535,9 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
   if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
     lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
-  } else if (isa<mlir::FloatType>(lhsElemTy) &&
-             isa<mlir::IntegerType>(rhsElemTy)) {
-    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
   } else {
-    if (lhsElemTy.getIntOrFloatBitWidth() >
-        rhsElemTy.getIntOrFloatBitWidth()) {
-      rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
-    } else {
-      lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
-    }
+    // use lhs's element type as compute type
+    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
Collaborator:
It seems this will cause torch.tensor([1.0]) < torch.tensor([2.0], dtype=torch.double) to be computed in fp32.
I think we need to discuss how to use StableHLO to describe torch's default compute type.
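
For reference, a quick eager-mode check of PyTorch's own promotion (an illustrative sketch, not code from this PR):

import torch

lhs = torch.tensor([1.0])                      # float32
rhs = torch.tensor([2.0], dtype=torch.double)  # float64
# Eager PyTorch promotes the pair to float64, so the comparison runs in
# fp64, whereas the lowering above would demote rhs and compare in fp32.
print(torch.result_type(lhs, rhs))  # torch.float64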

Collaborator:
See another PR: #3673

Collaborator (author):
In the previous code, the promotion was decided by comparing bit widths:

if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
} else if (isa<mlir::FloatType>(lhsElemTy) &&
isa<mlir::IntegerType>(rhsElemTy)) {
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
} else {
if (lhsElemTy.getIntOrFloatBitWidth() >
rhsElemTy.getIntOrFloatBitWidth()) {
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
} else {
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
}
}

Then #3518 changed this to promote rhsType to lhsType in advance. I'm not quite sure about its purpose; this PR just ensures that when lhsType is an integer type and rhsType is a float type, the comparison is done in floating point.

Collaborator:
I know. In torch, tensors are first-class and scalars are second-class in type promotion. When computing tensor < scalar, the tensor's dtype should usually be the compute type, with exceptions like torch.tensor([1]) < 1.1.
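
A quick illustration of that precedence in eager PyTorch (an illustrative sketch, not code from the PR):

import torch

# The tensor's dtype wins over an integer scalar:
print(torch.result_type(torch.tensor([1.0]), 2))  # torch.float32
# The exception: an int tensor vs. a float scalar promotes to the
# default float dtype instead of staying integral:
print(torch.result_type(torch.tensor([1]), 1.1))  # torch.float32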

Collaborator (author):
According to PyTorch's implementation, when comparing a tensor and a scalar, the scalar is first converted to a tensor using the scalar_to_tensor function before the comparison. Therefore, a tensor-scalar comparison should exhibit the same behavior as a tensor-tensor comparison.
https://github.com/pytorch/pytorch/blob/02169364e15932d886370d711482ef1cd5a5b137/aten/src/ATen/ScalarOps.h#L45-L51
So the correct semantics should be to maintain the approach used before #3518.
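
A minimal sketch of the equivalence described above (assumes eager-mode PyTorch and default dtypes):

import torch

t = torch.tensor([1], dtype=torch.int32)
# Comparing against the scalar directly behaves the same as comparing
# against the scalar materialized as a tensor, which is what ATen's
# scalar_to_tensor does internally.
assert torch.equal(t < 1.1, t < torch.tensor(1.1))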

   }
   lhsElemTy = dyn_cast<RankedTensorType>(lhs.getType()).getElementType();

1 change: 0 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
@@ -536,7 +536,6 @@
   "AtenPolarFloatModule_basic",
   "DiagonalWithStaticShapeModule_basic",
   "EinsumStaticDiagonalDimensionModule_basic",
-  "ElementwiseIntTensorLtFloatScalarModule_basic",
   "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
   "ElementwiseRemainderScalarModule_Float_NegativeDividend_basic",
   "ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",