111 changes: 106 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -4383,6 +4383,11 @@
    See implementation of `torch.onnx.symbolic_opset11.index_put
    <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
    """
    if len(indices) > 1 and any(
        isinstance(indice, torch.onnx._internal.exporter._tensors.SymbolicTensor)
        for indice in indices
    ) and len(values.shape) == 1:
        return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate)

Review thread on the isinstance check above:

justinchuby (Collaborator):
Check not isinstance(index, int) instead, as we should not reference the private class.

xadupre (Member Author):
What are the possible types here? Only int and SymbolicTensor? I prefer to keep SymbolicTensor because I know exactly which type the function is supposed to handle.

justinchuby (Collaborator), Oct 7, 2025:
Only int and SymbolicTensor for the exporter, but it could be other ir.Value subclasses as well. As the type is internal and not meant for public use, the current usage is unsupported and brittle.

If preferred you may check for ir.Value instead, but really we just assume a type that has shape and dtype fields.

xadupre (Member Author):
Feel free to suggest. I still prefer to keep this one since this is what shows up in the error message.
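A minimal sketch of the reviewer's alternative (not part of the PR), assuming ir here is onnxscript.ir and that non-int index entries are ir.Value instances:

    if len(indices) > 1 and any(
        isinstance(indice, ir.Value) for indice in indices
    ) and len(values.shape) == 1:
        return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate)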

    def _make_reshape_list_broadcastable(reshape_list, values_shape):
        # Remove ones until the rank of reshape_list matches values_shape.
@@ -4452,14 +4457,110 @@
     # Flatten values to match the indices
     flat_values = op.Reshape(values, [-1])

-    if accumulate:
-        result = op.ScatterND(self, new_index, flat_values, reduction="add")
-    else:
-        result = op.ScatterND(self, new_index, flat_values)
+    scatter_kwargs = dict(reduction="add") if accumulate else {}
+    result = op.ScatterND(self, new_index, flat_values, **scatter_kwargs)
     return result
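As an aside (not part of the diff), a minimal eager-mode illustration of what the accumulate flag changes; ScatterND with reduction="add" corresponds to accumulate=True:

    import torch

    x = torch.ones(3)
    indices = (torch.tensor([0, 1]),)
    values = torch.tensor([5.0, 7.0])

    # accumulate=False overwrites the addressed elements (plain ScatterND).
    print(x.index_put(indices, values, accumulate=False))  # tensor([5., 7., 1.])

    # accumulate=True adds into them (ScatterND with reduction="add").
    print(x.index_put(indices, values, accumulate=True))   # tensor([6., 8., 1.])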


def _aten_index_put_dynamic(
    x: TReal,
    indices: Sequence[INT64],
    values: TReal,
    accumulate: bool = False,
) -> TReal:
    # Implements index_put for rank-1 (possibly symbolic) indices by building a
    # flat index from per-dimension strides and scattering into the flattened x.
    def _1dint(i: int):
        return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i]))

    def _0dint(i: int):
        return op.Constant(value_int=ir.AttrInt64("value_int", i))

    def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int):
        # A given index is used as-is; a missing (None) index becomes the full
        # range [0, dim_size) of that dimension.
        if ind is not None:
            return op.Cast(ind, to=INT64.dtype), False
        return (
            op.Cast(
                op.Range(  # Range does not return a typed result
                    _0dint(0),
                    op.Squeeze(op.Shape(x, start=dim, end=dim + 1)),
                    _0dint(1),
                ),
                to=INT64.dtype,
            ),
            True,
        )

    rk1s = [(ind is None or len(ind.shape) == 1) for ind in indices]
    assert all(rk1s) and len(rk1s) == len(x.shape), (
        f"index_put not implemented for indices={indices}, "
        f"where rk1s={rk1s}, rank(x)={len(x.shape)}"
    )
    shape_x = op.Shape(x)
    # Split dimensions into expanded (None index, replaced by a full range)
    # and fixed (explicitly indexed).
    exped = []
    fixed = []
    reshape_value_shape2 = []
    expand_value_shape = []
    for i, ind in enumerate(indices):
        if isinstance(ind, torch.onnx._internal.exporter._tensors.SymbolicTensor):
            ind.dtype = ir.DataType.INT64
        ind, expanded = _make_range_or_cast(ind, shape_x, False, i)
        if expanded:
            exped.append((i, ind))
            expand_value_shape.append(op.Shape(x, start=i, end=i + 1))
            reshape_value_shape2.append(_1dint(1))
        else:
            expand_value_shape.append(_1dint(1))
            reshape_value_shape2.append(op.Shape(ind))
            fixed.append((i, ind))

    reshape_value_shape1 = [_1dint(1)] * len(indices)
    if len(fixed) <= 1:
        reshape_value_shape1 = None
    elif fixed:
        reshape_value_shape1[fixed[-1][0]] = _1dint(-1)

    def _mkstride(x, i):
        # Stride of dimension i in elements: the product of the trailing dim sizes.
        if i >= len(x.shape) - 1:
            return _1dint(1)
        if i == len(x.shape) - 2:
            return op.Shape(x, start=i + 1)
        return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1)

    shape = [1] * (len(x.shape) + 1)
    mfixed = []
Review thread on the variable names below:

justinchuby (Collaborator):
Prefer clear variable names and avoid abbreviations.

xadupre (Member Author), Oct 7, 2025:
The variables are used just after. Keeping them short makes the code shorter and easier to read.

justinchuby (Collaborator), Oct 7, 2025:
Any abbreviation adds cognitive load and is not preferred. All names should be thoughtfully chosen to be precise. Defining variables right before use is a plus, but I don't feel it warrants the use of ambiguous names.

https://google.github.io/styleguide/pyguide.html#316-naming

xadupre (Member Author):
I let you choose the name.
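    # Review sketch: illustrative longer names for mfixed / mexped below could be
    # fixed_offset_terms / expanded_offset_terms (my suggestion, not from the PR).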

    if fixed:
        new_shape = shape.copy()
        new_shape[-1] = -1
        mfixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed]

    mexped = []
    for i, e in exped:
        new_shape = shape.copy()
        new_shape[i] = -1
        mexped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape))

    # Final sum: broadcasting adds every stride-scaled term into one flat index grid.
    unflat = None
    for a in [*mfixed, *mexped]:
        if unflat is None:
            unflat = a
            continue
        unflat = op.Add(unflat, a)

    # Reshape/expand values to match the index grid.
    expanded_values = values
    if reshape_value_shape1 is not None:
        expanded_values = op.Reshape(expanded_values, op.Concat(*reshape_value_shape1, axis=0))
    # Bug here: Error calling operator 'Concat' with args
    # (SymbolicTensor(name='anonymous:124529632436112', producer=anonymous_node:124529631522416, index=0), [1], [1])
    expanded_values = op.Expand(expanded_values, op.Concat(*expand_value_shape, axis=0))
    flat_ind = op.Reshape(unflat, _1dint(-1))
    expanded_values = op.Reshape(expanded_values, _1dint(-1))
    flat_x = op.Reshape(x, _1dint(-1))
    scat_kwargs = {"reduction": "add"} if accumulate else {}
    flat_up_x = op.ScatterElements(flat_x, flat_ind, expanded_values, **scat_kwargs)
    return op.Reshape(flat_up_x, op.Shape(x))
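For intuition, a minimal eager-mode sketch (plain PyTorch, my own illustration, not part of the PR) of the flattening trick the function builds out of ONNX ops: stride-scale each index, broadcast-add the terms into one flat index, then scatter into the flattened tensor.

    import torch

    x = torch.zeros(2, 4, 5)
    index1 = torch.tensor([1, 2], dtype=torch.int64)
    index2 = torch.tensor([3, 4], dtype=torch.int64)
    values = torch.tensor([10.0, 11.0])

    # Reference result: dim 0 is implicit (a full range), dims 1 and 2 are indexed.
    expected = x.clone()
    expected[:, index1, index2] = values

    # Strides of x in elements: stride(d) = product of the trailing dim sizes.
    strides = (4 * 5, 5, 1)
    range0 = torch.arange(2, dtype=torch.int64)
    # The expanded dim broadcasts on its own axis; fixed dims share the last axis.
    flat_index = (
        range0.reshape(2, 1) * strides[0]
        + (index1 * strides[1] + index2 * strides[2]).reshape(1, 2)
    ).reshape(-1)
    flat_values = values.reshape(1, 2).expand(2, 2).reshape(-1)
    result = x.reshape(-1).scatter(0, flat_index, flat_values).reshape(x.shape)
    assert torch.equal(result, expected)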


@torch_op("aten::index_put", trace_only=True)
def aten_index_put_bool(
    self: TReal,
43 changes: 43 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -225,6 +225,49 @@ def forward(self, q, k, v):
        )
        _testing.assert_onnx_program(onnx_program)

    def test_index_put_dynamic(self):
        for dimension in [3, 4, 2]:
            with self.subTest(dimension=dimension):

                class Model(torch.nn.Module):
                    def __init__(self, dimension):
                        super().__init__()
                        self.params = torch.zeros(
                            (4, 5)
                            if dimension == 2
                            else ((2, 4, 5) if dimension == 3 else (1, 1, 4, 5))
                        )
                        self.dimension = dimension

                    def forward(self, update, index1, index2):
                        copy = self.params.clone()
                        if self.dimension == 2:
                            copy[index1, index2] = update
                        elif self.dimension == 3:
                            copy[:, index1, index2] = update
                        else:
                            copy[:, :, index1, index2] = update
                        return copy

                update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32)
                index1 = torch.tensor([1, 2], dtype=torch.int64)
                index2 = torch.tensor([3, 4], dtype=torch.int64)
                feeds = dict(zip(["update", "index1", "index2"], (update, index1, index2)))
                onnx_program = torch.onnx.export(
                    Model(dimension),
                    tuple(feeds.values()),
                    input_names=["update", "index1", "index2"],
                    output_names=["output"],
                    opset_version=18,
                    dynamo=True,
                    dynamic_shapes={
                        "update": {0: "dn"},
                        "index1": {0: "dn"},
                        "index2": {0: "dn"},
                    },
                )
                _testing.assert_onnx_program(onnx_program)


if __name__ == "__main__":
    unittest.main()
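To run just these tests locally, something like the following should work from the repository root (my assumption; the PR does not state a command):

    python -m unittest tests.function_libs.torch_lib.e2e_ops_tests -k index_put_dynamic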