Skip to content

Conversation

penguin-wwy
Copy link
Collaborator

No description provided.

lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
}
// use lhs's element type as compute type
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems that this will cause torch.tensor([1.0]) < torch.tensor([2.0],dtype=torch.double) will compute on fp32.
I think we need to discuss how to use stablehlo to describe torch's default compute type.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see another PR:#3673

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the previous code, the comparison was done based on the number of bits

if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
} else if (isa<mlir::FloatType>(lhsElemTy) &&
isa<mlir::IntegerType>(rhsElemTy)) {
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
} else {
if (lhsElemTy.getIntOrFloatBitWidth() >
rhsElemTy.getIntOrFloatBitWidth()) {
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
} else {
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
}
}

And then, #3518 promotes rhsType to lhsType in advance. I'm not quite sure about its purpose; this PR just ensures that when lhsType.isInterger && rhsType.isFloat, the comparison is done in a float manner.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know. In torch, tensor is first level, scalar is second level. When compute tensor < scalar, mostly should use tensor's dtype as compute type except torch.tensor([1]) < 1.1.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to PyTorch's implementation, when comparing a tensor and a scalar, the scalar is first converted to a tensor using the scalar_to_tensor function before the comparison. Therefore, compare tensor scalar should exhibit the same behavior as compare tensor tensor.
https://github.com/pytorch/pytorch/blob/02169364e15932d886370d711482ef1cd5a5b137/aten/src/ATen/ScalarOps.h#L45-L51
Therefore, the correct semantics should be to maintain the approach used before #3518 .

qingyunqu added a commit that referenced this pull request Sep 13, 2024
qingyunqu added a commit that referenced this pull request Sep 13, 2024
@penguin-wwy penguin-wwy deleted the fix_stablehlo branch September 14, 2024 08:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants