-
Notifications
You must be signed in to change notification settings - Fork 87
Implements aten_index_put if inputs are SymbolicTensor #2606
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
3cb493f
c9472f3
434fcfb
5cf4882
ae6adca
8510364
e4d574a
86d482d
8305cac
02dda0e
e6f7633
e108dc3
988e9f6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4383,6 +4383,11 @@ | |
| See implementation of `torch.onnx.symbolic_opset11.index_put | ||
| <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_. | ||
| """ | ||
| if len(indices) > 1 and any( | ||
| isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor) | ||
|
||
| for indice in indices | ||
| ) and len(values.shape) == 1: | ||
| return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate) | ||
xadupre marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def _make_reshape_list_broadcastable(reshape_list, values_shape): | ||
| # Remove ones until the rank of reshape_list matches values_shape. | ||
|
|
@@ -4452,14 +4457,110 @@ | |
| # Flatten values to match the indices | ||
| flat_values = op.Reshape(values, [-1]) | ||
|
|
||
| if accumulate: | ||
| result = op.ScatterND(self, new_index, flat_values, reduction="add") | ||
| else: | ||
| result = op.ScatterND(self, new_index, flat_values) | ||
|
|
||
| scatter_kwargs = dict(reduction="add") if accumulate else {} | ||
| result = op.ScatterND(self, new_index, flat_values, **scatter_kwargs) | ||
| return result | ||
|
|
||
|
|
||
def _aten_index_put_dynamic(
    x: TReal,
    indices: Sequence[INT64],
    values: TReal,
    accumulate: bool = False,
) -> TReal:
    """Implements aten_index_put when some indices are symbolic (dynamic) tensors.

    Flattens ``x``, converts the per-dimension indices into flat offsets
    (stride-weighted, summed with broadcasting), scatters ``values`` with
    ScatterElements, and reshapes back to the original shape of ``x``.

    Args:
        x: Input tensor to update.
        indices: One rank-1 index tensor (or ``None``) per dimension of ``x``.
            ``None`` means "all positions along that dimension".
        values: Tensor of values to write; broadcast to the indexed positions.
        accumulate: If True, scattered values are added to the existing
            elements instead of replacing them.

    Returns:
        A tensor with the same shape as ``x`` containing the updated values.
    """

    def _1dint(i: int):
        # 1-D constant tensor [i].
        return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i]))

    def _0dint(i: int):
        # 0-D (scalar) constant tensor i.
        return op.Constant(value_int=ir.AttrInt64("value_int", i))

    def _make_range_or_cast(ind, dim: int):
        """Returns (index tensor cast to INT64, was_expanded).

        When ``ind`` is None, the full range [0, x.shape[dim]) is generated,
        i.e. every position along that dimension is selected.
        """
        if ind is not None:
            return op.Cast(ind, to=INT64.dtype), False
        return (
            op.Cast(
                op.Range(  # Range does not return a typed result
                    _0dint(0),
                    op.Squeeze(op.Shape(x, start=dim, end=dim + 1)),
                    _0dint(1),
                ),
                to=INT64.dtype,
            ),
            True,
        )

    # Only rank-1 (or None) indices covering every dimension are supported.
    rk1s = [(ind is None or len(ind.shape) == 1) for ind in indices]
    assert all(rk1s) and len(rk1s) == len(x.shape), (
        f"index_put not implemented for indices={indices}, "
        f"where rk1s={rk1s}, rank(x)={len(x.shape)}"
    )

    exped = []  # (dim, index) pairs where the index was a generated full range
    fixed = []  # (dim, index) pairs where the caller provided the index
    expand_value_shape = []  # per-dim target shape used to Expand `values`
    for i, ind in enumerate(indices):
        if isinstance(ind, torch.onnx._internal.exporter._tensors.SymbolicTensor):
            # Symbolic tensors may carry no dtype; the exporter needs INT64 here.
            ind.dtype = ir.DataType.INT64
        ind, expanded = _make_range_or_cast(ind, i)
        if expanded:
            exped.append((i, ind))
            expand_value_shape.append(op.Shape(x, start=i, end=i + 1))
        else:
            expand_value_shape.append(_1dint(1))
            fixed.append((i, ind))

    # Reshape `values` so the caller-provided indices line up on the last of
    # the fixed dimensions; with 0 or 1 fixed dims no reshape is needed.
    reshape_value_shape1 = [_1dint(1)] * len(indices)
    if len(fixed) <= 1:
        reshape_value_shape1 = None
    else:
        reshape_value_shape1[fixed[-1][0]] = _1dint(-1)

    def _mkstride(x, i):
        """Stride (number of flat elements) of dimension ``i`` of ``x``."""
        if i >= len(x.shape) - 1:
            return _1dint(1)
        if i == len(x.shape) - 2:
            return op.Shape(x, start=i + 1)
        return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1)

    # Each per-dimension index is multiplied by its stride and reshaped so
    # that the later Add broadcasts them into a full cartesian offset grid.
    shape = [1] * (len(x.shape) + 1)
    mfixed = []
    if fixed:
        new_shape = shape.copy()
        new_shape[-1] = -1
        mfixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed]

    mexped = []
    for i, e in exped:
        new_shape = shape.copy()
        new_shape[i] = -1
        mexped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape))

    # Sum all stride-weighted indices (with broadcasting) into flat offsets.
    unflat = None
    for a in [*mfixed, *mexped]:
        if unflat is None:
            unflat = a
            continue
        unflat = op.Add(unflat, a)

    # Broadcast `values` to one value per flat offset.
    expanded_values = values
    if reshape_value_shape1 is not None:
        expanded_values = op.Reshape(expanded_values, op.Concat(*reshape_value_shape1, axis=0))
    # NOTE(review): a previous revision reported Concat failing here when the
    # inputs mix SymbolicTensor and plain list constants — verify this path
    # with symbolic shape inputs.
    expanded_values = op.Expand(expanded_values, op.Concat(*expand_value_shape, axis=0))
    flat_ind = op.Reshape(unflat, _1dint(-1))
    expanded_values = op.Reshape(expanded_values, _1dint(-1))
    flat_x = op.Reshape(x, _1dint(-1))
    scat_kwargs = {"reduction": "add"} if accumulate else {}
    flat_up_x = op.ScatterElements(flat_x, flat_ind, expanded_values, **scat_kwargs)
    return op.Reshape(flat_up_x, op.Shape(x))
|
|
||
|
|
||
| @torch_op("aten::index_put", trace_only=True) | ||
| def aten_index_put_bool( | ||
| self: TReal, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.