-
Notifications
You must be signed in to change notification settings - Fork 604
[Torch] Fold aten.to.dtype
on splat constants.
#4306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Not sure who can review, maybe you would know @vivekkhandelwal1 @zjgarvey ? |
3bf4e4b
to
1d7b55b
Compare
9b8168c
to
42edabd
Compare
This commit teaches `AtenToDtypeOp::fold` to constant-fold dtype conversions when the operand is a splat `DenseElementsAttr`. Folding is done according to torch's rounding behavior, i.e. * Bool: 0 and -0.0 → false; nonzero/NaN/±Inf → true. * Float → Int: round toward zero. * Int → Float: sign-aware, rmNearestTiesToEven. * Float ↔ Float: use builtin `mlir::FloatType::getFloatSemantics()`. * Int ↔ Int: use `zextOrTrunc` / `sextOrTrunc` based on source signedness. Folding is only performed when `non_blocking == false`, `copy == false`, and `memory_format` is None.
42edabd
to
3abbd48
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the folder improvements! Sorry for the long turnaround.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry this one slipped through the cracks. LGTM.
One suggestion for future is to not amend the git commit for changes that has been reviewed already when addressing feedback, so that it's easy for reviewers to only review the changes since the last feedback was provided. Thanks!
This commit teaches
AtenToDtypeOp::fold
to constant-fold dtype conversions when the operand is a splatDenseElementsAttr
.Folding is done according to torch's rounding behavior, i.e.
mlir::FloatType::getFloatSemantics()
.zextOrTrunc
/sextOrTrunc
based on source signedness.Folding is only performed when
non_blocking == false
,copy == false
, andmemory_format
is None.