Skip to content

Commit 5ccb0f3

Browse files
authored
Fix cast inference bug (#1884)
Fix bug introduced while migrating cast-logic to new IR in optimizer. --------- Signed-off-by: gramalingam <[email protected]>
1 parent 0283167 commit 5ccb0f3

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

onnxscript/ir/_core.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,10 @@ def __init__(
879879
)
880880
self._frozen: bool = frozen
881881

882+
def copy(self):
883+
"""Return a copy of the shape."""
884+
return Shape(self._dims, self._denotations, self._frozen)
885+
882886
@property
883887
def dims(self) -> tuple[int | SymbolicDim, ...]:
884888
"""All dimensions in the shape.

onnxscript/optimizer/_constant_folding.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,13 @@ def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
236236
input = _get_input(node, 0)
237237
output = _get_output(node, 0)
238238
if input is not None and output is not None:
239-
_update_type(output, input.type)
239+
input_shape = input.shape
240+
if input_shape is not None:
241+
output.shape = input_shape.copy()
242+
if output is not None:
243+
output_dtype = _get_int_attribute(node, "to", None)
244+
if output_dtype is not None:
245+
output.type = ir.TensorType(ir.DataType(output_dtype))
240246
return None
241247

242248

0 commit comments

Comments
 (0)