diff --git a/pytensor/link/jax/dispatch/subtensor.py b/pytensor/link/jax/dispatch/subtensor.py
index 1c659be29b..cd8f78575a 100644
--- a/pytensor/link/jax/dispatch/subtensor.py
+++ b/pytensor/link/jax/dispatch/subtensor.py
@@ -77,6 +77,8 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):

 @jax_funcify.register(AdvancedIncSubtensor)
 def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
+    idx_list = getattr(op, "idx_list", None)
+
     if getattr(op, "set_instead_of_inc", False):

         def jax_fn(x, indices, y):
@@ -87,8 +89,11 @@ def jax_fn(x, indices, y):
         def jax_fn(x, indices, y):
             return x.at[indices].add(y)

-    def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
-        return jax_fn(x, ilist, y)
+    def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
+        indices = indices_from_subtensor(ilist, idx_list)
+        if len(indices) == 1:
+            indices = indices[0]
+        return jax_fn(x, indices, y)

     return advancedincsubtensor

diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py
index fe0eda153e..4848b2cb25 100644
--- a/pytensor/link/numba/dispatch/subtensor.py
+++ b/pytensor/link/numba/dispatch/subtensor.py
@@ -107,28 +107,30 @@ def {function_name}({", ".join(input_names)}):

 @numba_funcify.register(AdvancedIncSubtensor)
 def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
     if isinstance(op, AdvancedSubtensor):
-        x, y, idxs = node.inputs[0], None, node.inputs[1:]
+        x, y, tensor_inputs = node.inputs[0], None, node.inputs[1:]
     else:
-        x, y, *idxs = node.inputs
-
-    basic_idxs = [
-        idx
-        for idx in idxs
-        if (
-            isinstance(idx.type, NoneTypeT)
-            or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
-        )
-    ]
-    adv_idxs = [
-        {
-            "axis": i,
-            "dtype": idx.type.dtype,
-            "bcast": idx.type.broadcastable,
-            "ndim": idx.type.ndim,
-        }
-        for i, idx in enumerate(idxs)
-        if isinstance(idx.type, TensorType)
-    ]
+        x, y, *tensor_inputs = node.inputs
+
+    # Reconstruct indexing information from idx_list and tensor inputs
+    basic_idxs = []
+    adv_idxs = []
+    input_idx = 0
+
+    for i, entry in enumerate(op.idx_list):
+        if isinstance(entry, slice):
+            # Basic slice index
+            basic_idxs.append(entry)
+        elif isinstance(entry, Type):
+            # Advanced tensor index
+            if input_idx < len(tensor_inputs):
+                idx_input = tensor_inputs[input_idx]
+                adv_idxs.append({
+                    "axis": i,
+                    "dtype": idx_input.type.dtype,
+                    "bcast": idx_input.type.broadcastable,
+                    "ndim": idx_input.type.ndim,
+                })
+                input_idx += 1

     # Special implementation for consecutive integer vector indices
     if (
diff --git a/pytensor/link/pytorch/dispatch/subtensor.py b/pytensor/link/pytorch/dispatch/subtensor.py
index 5dfa7dfa36..7e96a816c2 100644
--- a/pytensor/link/pytorch/dispatch/subtensor.py
+++ b/pytensor/link/pytorch/dispatch/subtensor.py
@@ -63,7 +63,10 @@ def makeslice(start, stop, step):

 @pytorch_funcify.register(AdvancedSubtensor1)
 @pytorch_funcify.register(AdvancedSubtensor)
 def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
-    def advsubtensor(x, *indices):
+    idx_list = getattr(op, "idx_list", None)
+
+    def advsubtensor(x, *flattened_indices):
+        indices = indices_from_subtensor(flattened_indices, idx_list)
         check_negative_steps(indices)
         return x[indices]

@@ -102,12 +105,14 @@ def inc_subtensor(x, y, *flattened_indices):

 @pytorch_funcify.register(AdvancedIncSubtensor)
 @pytorch_funcify.register(AdvancedIncSubtensor1)
 def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
+    idx_list = getattr(op, "idx_list", None)
     inplace = op.inplace
     ignore_duplicates = getattr(op, "ignore_duplicates", False)

     if op.set_instead_of_inc:

-        def adv_set_subtensor(x, y, *indices):
+        def adv_set_subtensor(x, y, *flattened_indices):
+            indices = indices_from_subtensor(flattened_indices, idx_list)
             check_negative_steps(indices)
             if isinstance(op, AdvancedIncSubtensor1):
                 op._check_runtime_broadcasting(node, x, y, indices)
@@ -120,7 +125,8 @@ def adv_set_subtensor(x, y, *indices):

     elif ignore_duplicates:

-        def adv_inc_subtensor_no_duplicates(x, y, *indices):
+        def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices):
+            indices = indices_from_subtensor(flattened_indices, idx_list)
             check_negative_steps(indices)
             if isinstance(op, AdvancedIncSubtensor1):
                 op._check_runtime_broadcasting(node, x, y, indices)
@@ -132,13 +138,16 @@ def adv_inc_subtensor_no_duplicates(x, y, *indices):
         return adv_inc_subtensor_no_duplicates

     else:
-        if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
+        # Check if we have slice indexing in idx_list
+        has_slice_indexing = any(isinstance(entry, slice) for entry in idx_list) if idx_list else False
+        if has_slice_indexing:
             raise NotImplementedError(
                 "IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
             )

-        def adv_inc_subtensor(x, y, *indices):
-            # Not needed because slices aren't supported
+        def adv_inc_subtensor(x, y, *flattened_indices):
+            indices = indices_from_subtensor(flattened_indices, idx_list)
+            # Not needed because slices aren't supported in this path
             # check_negative_steps(indices)
             if not inplace:
                 x = x.clone()
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index fbe97b9a68..599e3497d3 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -228,7 +228,18 @@ def local_replace_AdvancedSubtensor(fgraph, node):
         return

     indexed_var = node.inputs[0]
-    indices = node.inputs[1:]
+    tensor_inputs = node.inputs[1:]
+
+    # Reconstruct indices from idx_list and tensor inputs
+    indices = []
+    input_idx = 0
+    for entry in node.op.idx_list:
+        if isinstance(entry, slice):
+            indices.append(entry)
+        elif isinstance(entry, Type):
+            if input_idx < len(tensor_inputs):
+                indices.append(tensor_inputs[input_idx])
+                input_idx += 1

     axis = get_advsubtensor_axis(indices)

@@ -255,7 +266,18 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):

     res = node.inputs[0]
     val = node.inputs[1]
-    indices = node.inputs[2:]
+    tensor_inputs = node.inputs[2:]
+
+    # Reconstruct indices from idx_list and tensor inputs
+    indices = []
+    input_idx = 0
+    for entry in node.op.idx_list:
+        if isinstance(entry, slice):
+            indices.append(entry)
+        elif isinstance(entry, Type):
+            if input_idx < len(tensor_inputs):
+                indices.append(tensor_inputs[input_idx])
+                input_idx += 1

     axis = get_advsubtensor_axis(indices)

@@ -1751,9 +1773,22 @@ def ravel_multidimensional_bool_idx(fgraph, node):
     x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
     """
     if isinstance(node.op, AdvancedSubtensor):
-        x, *idxs = node.inputs
+        x = node.inputs[0]
+        tensor_inputs = node.inputs[1:]
     else:
-        x, y, *idxs = node.inputs
+        x, y = node.inputs[0], node.inputs[1]
+        tensor_inputs = node.inputs[2:]
+
+    # Reconstruct indices from idx_list and tensor inputs
+    idxs = []
+    input_idx = 0
+    for entry in node.op.idx_list:
+        if isinstance(entry, slice):
+            idxs.append(entry)
+        elif isinstance(entry, Type):
+            if input_idx < len(tensor_inputs):
+                idxs.append(tensor_inputs[input_idx])
+                input_idx += 1

     if any(
         (
@@ -1791,12 +1826,41 @@ def ravel_multidimensional_bool_idx(fgraph, node):
         new_idxs[bool_idx_pos] = raveled_bool_idx

     if isinstance(node.op, AdvancedSubtensor):
-        new_out = node.op(raveled_x, *new_idxs)
+        # Create new AdvancedSubtensor with updated idx_list
+        new_idx_list = list(node.op.idx_list)
+        new_tensor_inputs = list(tensor_inputs)
+
+        # Update the idx_list and tensor_inputs for the raveled boolean index
+        input_idx = 0
+        for i, entry in enumerate(node.op.idx_list):
+            if isinstance(entry, Type):
+                if input_idx == bool_idx_pos:
+                    new_tensor_inputs[input_idx] = raveled_bool_idx
+                input_idx += 1
+
+        new_out = AdvancedSubtensor(new_idx_list)(raveled_x, *new_tensor_inputs)
     else:
+        # Create new AdvancedIncSubtensor with updated idx_list
+        new_idx_list = list(node.op.idx_list)
+        new_tensor_inputs = list(tensor_inputs)
+
+        # Update the tensor_inputs for the raveled boolean index
+        input_idx = 0
+        for i, entry in enumerate(node.op.idx_list):
+            if isinstance(entry, Type):
+                if input_idx == bool_idx_pos:
+                    new_tensor_inputs[input_idx] = raveled_bool_idx
+                input_idx += 1
+
         # The dimensions of y that correspond to the boolean indices
         # must already be raveled in the original graph, so we don't need to do anything to it
-        new_out = node.op(raveled_x, y, *new_idxs)
-        # But we must reshape the output to math the original shape
+        new_out = AdvancedIncSubtensor(
+            new_idx_list,
+            inplace=node.op.inplace,
+            set_instead_of_inc=node.op.set_instead_of_inc,
+            ignore_duplicates=node.op.ignore_duplicates
+        )(raveled_x, y, *new_tensor_inputs)
+        # But we must reshape the output to match the original shape
         new_out = new_out.reshape(x_shape)

     return [copy_stack_trace(node.outputs[0], new_out)]
diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index 15e02265f1..c32fecb841 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -2796,78 +2796,124 @@ def check_advanced_indexing_dimensions(input, idx_list):

 class AdvancedSubtensor(Op):
     """Implements NumPy's advanced indexing."""

-    __props__ = ()
+    __props__ = ("idx_list",)

-    def make_node(self, x, *indices):
+    def __init__(self, idx_list):
+        """
+        Initialize AdvancedSubtensor with index list.
+
+        Parameters
+        ----------
+        idx_list : tuple
+            List of indices where slices are stored as-is,
+            and numerical indices are replaced by their types.
+        """
+        self.idx_list = tuple(map(index_vars_to_types, idx_list))
+        # Store expected number of tensor inputs for validation
+        self.expected_inputs_len = len(get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type)))
+
+    def make_node(self, x, *inputs):
+        """
+        Parameters
+        ----------
+        x
+            The tensor to take a subtensor of.
+        inputs
+            A list of pytensor Scalars and Tensors (numerical indices only).
+
+        """
         x = as_tensor_variable(x)
-        indices = tuple(map(as_index_variable, indices))
+        inputs = tuple(as_tensor_variable(a) for a in inputs)
+        idx_list = list(self.idx_list)
+
+        if len(idx_list) > x.type.ndim:
+            raise IndexError("too many indices for array")
+
+        # Validate input count matches expected from idx_list
+        if len(inputs) != self.expected_inputs_len:
+            raise ValueError(f"Expected {self.expected_inputs_len} inputs but got {len(inputs)}")
+
+        # Build explicit_indices for shape inference
         explicit_indices = []
-        new_axes = []
-        for idx in indices:
-            if isinstance(idx.type, TensorType) and idx.dtype == "bool":
-                if idx.type.ndim == 0:
-                    raise NotImplementedError(
-                        "Indexing with scalar booleans not supported"
-                    )
+        input_idx = 0
+
+        for i, entry in enumerate(idx_list):
+            if isinstance(entry, slice):
+                # Reconstruct slice with actual values from inputs
+                if entry.start is not None and isinstance(entry.start, Type):
+                    start_val = inputs[input_idx]
+                    input_idx += 1
+                else:
+                    start_val = entry.start
+
+                if entry.stop is not None and isinstance(entry.stop, Type):
+                    stop_val = inputs[input_idx]
+                    input_idx += 1
+                else:
+                    stop_val = entry.stop
+
+                if entry.step is not None and isinstance(entry.step, Type):
+                    step_val = inputs[input_idx]
+                    input_idx += 1
+                else:
+                    step_val = entry.step
+
+                explicit_indices.append(slice(start_val, stop_val, step_val))
+            elif isinstance(entry, Type):
+                # This is a numerical index
+                inp = inputs[input_idx]
+                input_idx += 1
+
+                # Handle boolean indices
+                if inp.dtype == "bool":
+                    if inp.type.ndim == 0:
+                        raise NotImplementedError(
+                            "Indexing with scalar booleans not supported"
+                        )

-                # Check static shape aligned
-                axis = len(explicit_indices) - len(new_axes)
-                indexed_shape = x.type.shape[axis : axis + idx.type.ndim]
-                for j, (indexed_length, indexer_length) in enumerate(
-                    zip(indexed_shape, idx.type.shape)
-                ):
-                    if (
-                        indexed_length is not None
-                        and indexer_length is not None
-                        and indexed_length != indexer_length
+                    # Check static shape aligned
+                    axis = len(explicit_indices)
+                    indexed_shape = x.type.shape[axis : axis + inp.type.ndim]
+                    for j, (indexed_length, indexer_length) in enumerate(
+                        zip(indexed_shape, inp.type.shape)
                     ):
-                        raise IndexError(
-                            f"boolean index did not match indexed tensor along axis {axis + j};"
-                            f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}"
-                        )
-                # Convert boolean indices to integer with nonzero, to reason about static shape next
-                if isinstance(idx, Constant):
-                    nonzero_indices = [tensor_constant(i) for i in idx.data.nonzero()]
+                        if (
+                            indexed_length is not None
+                            and indexer_length is not None
+                            and indexed_length != indexer_length
+                        ):
+                            raise IndexError(
+                                f"boolean index did not match indexed tensor along axis {axis + j};"
+                                f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}"
+                            )
+                    # Convert boolean indices to integer with nonzero
+                    if isinstance(inp, Constant):
+                        nonzero_indices = [tensor_constant(i) for i in inp.data.nonzero()]
+                    else:
+                        nonzero_indices = inp.nonzero()
+                    explicit_indices.extend(nonzero_indices)
                 else:
-                    # Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero
-                    # and seeing that other integer indices cannot possible match it
-                    nonzero_indices = idx.nonzero()
-                explicit_indices.extend(nonzero_indices)
+                    # Regular numerical index
+                    explicit_indices.append(inp)
             else:
-                if isinstance(idx.type, NoneTypeT):
-                    new_axes.append(len(explicit_indices))
-                explicit_indices.append(idx)
+                raise ValueError(f"Invalid entry in idx_list: {entry}")

-        if (len(explicit_indices) - len(new_axes)) > x.type.ndim:
+        if len(explicit_indices) > x.type.ndim:
             raise IndexError(
-                f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices) - len(new_axes)} were indexed"
+                f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices)} were indexed"
             )

-        # Perform basic and advanced indexing shape inference separately
+        # Perform basic and advanced indexing shape inference separately (no newaxis)
         basic_group_shape = []
         advanced_indices = []
         adv_group_axis = None
         last_adv_group_axis = None
-        expanded_x_shape = tuple(
-            np.insert(np.array(x.type.shape, dtype=object), 1, new_axes)
-        )
         for i, (idx, dim_length) in enumerate(
-            zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst)
+            zip_longest(explicit_indices, x.type.shape, fillvalue=slice(None))
         ):
-            if isinstance(idx.type, NoneTypeT):
-                basic_group_shape.append(1)  # New-axis
-            elif isinstance(idx.type, SliceType):
-                if isinstance(idx, Constant):
-                    basic_group_shape.append(slice_static_length(idx.data, dim_length))
-                elif idx.owner is not None and isinstance(idx.owner.op, MakeSlice):
-                    basic_group_shape.append(
-                        slice_static_length(slice(*idx.owner.inputs), dim_length)
-                    )
-                else:
-                    # Symbolic root slice (owner is None), or slice operation we don't understand
-                    basic_group_shape.append(None)
-            else:  # TensorType
+            if isinstance(idx, slice):
+                basic_group_shape.append(slice_static_length(idx, dim_length))
+            else:  # TensorType (advanced index)
                 # Keep track of advanced group axis
                 if adv_group_axis is None:
                     # First time we see an advanced index
@@ -2902,7 +2948,7 @@ def make_node(self, x, *indices):

         return Apply(
             self,
-            [x, *indices],
+            [x, *inputs],
             [tensor(dtype=x.type.dtype, shape=tuple(indexed_shape))],
         )

@@ -2918,19 +2964,57 @@ def is_bool_index(idx):
                 or getattr(idx, "dtype", None) == "bool"
             )

-        indices = node.inputs[1:]
+        # Reconstruct the full indices from idx_list and inputs (newaxis handled by __getitem__)
+        inputs = node.inputs[1:]
+
+        full_indices = []
+        input_idx = 0
+
+        for entry in self.idx_list:
+            if isinstance(entry, slice):
+                # Reconstruct slice from idx_list and inputs
+                if entry.start is not None and isinstance(entry.start, Type):
+                    start_val = inputs[input_idx]
+                    input_idx += 1
+                else:
+                    start_val = entry.start
+
+                if entry.stop is not None and isinstance(entry.stop, Type):
+                    stop_val = inputs[input_idx]
+                    input_idx += 1
+                else:
+                    stop_val = entry.stop
+
+                if entry.step is not None and isinstance(entry.step, Type):
+                    step_val = inputs[input_idx]
+                    input_idx += 1
+                else:
+                    step_val = entry.step
+
+                full_indices.append(slice(start_val, stop_val, step_val))
+            elif isinstance(entry, Type):
+                # This is a numerical index - get from inputs
+                if input_idx < len(inputs):
+                    full_indices.append(inputs[input_idx])
+                    input_idx += 1
+                else:
+                    raise ValueError("Mismatch between idx_list and inputs")
+
         index_shapes = []
-        for idx, ishape in zip(indices, ishapes[1:], strict=True):
-            # Mixed bool indexes are converted to nonzero entries
-            shape0_op = Shape_i(0)
-            if is_bool_index(idx):
-                index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx))
-            # The `ishapes` entries for `SliceType`s will be None, and
-            # we need to give `indexed_result_shape` the actual slices.
-            elif isinstance(getattr(idx, "type", None), SliceType):
+        for idx in full_indices:
+            if isinstance(idx, slice):
                 index_shapes.append(idx)
+            elif hasattr(idx, 'type'):
+                # Mixed bool indexes are converted to nonzero entries
+                shape0_op = Shape_i(0)
+                if is_bool_index(idx):
+                    index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx))
+                else:
+                    # Get ishape for this input
+                    input_shape_idx = inputs.index(idx) + 1  # +1 because ishapes[0] is x
+                    index_shapes.append(ishapes[input_shape_idx])
             else:
-                index_shapes.append(ishape)
+                index_shapes.append(idx)

         res_shape = list(
             indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
@@ -2960,14 +3044,54 @@ def is_bool_index(idx):

     def perform(self, node, inputs, out_):
         (out,) = out_
-        check_advanced_indexing_dimensions(inputs[0], inputs[1:])
-        rval = inputs[0].__getitem__(tuple(inputs[1:]))
+
+        # Reconstruct the full tuple of indices from idx_list and inputs (newaxis handled by __getitem__)
+        x = inputs[0]
+        tensor_inputs = inputs[1:]
+
+        full_indices = []
+        input_idx = 0
+
+        for entry in self.idx_list:
+            if isinstance(entry, slice):
+                # Reconstruct slice from idx_list and inputs
+                if entry.start is not None and isinstance(entry.start, Type):
+                    start_val = tensor_inputs[input_idx]
+                    input_idx += 1
+                else:
+                    start_val = entry.start
+
+                if entry.stop is not None and isinstance(entry.stop, Type):
+                    stop_val = tensor_inputs[input_idx]
+                    input_idx += 1
+                else:
+                    stop_val = entry.stop
+
+                if entry.step is not None and isinstance(entry.step, Type):
+                    step_val = tensor_inputs[input_idx]
+                    input_idx += 1
+                else:
+                    step_val = entry.step
+
+                full_indices.append(slice(start_val, stop_val, step_val))
+            elif isinstance(entry, Type):
+                # This is a numerical index - get from inputs
+                if input_idx < len(tensor_inputs):
+                    full_indices.append(tensor_inputs[input_idx])
+                    input_idx += 1
+                else:
+                    raise ValueError("Mismatch between idx_list and inputs")
+
+        check_advanced_indexing_dimensions(x, full_indices)
+        rval = x.__getitem__(tuple(full_indices))
         # When there are no arrays, we are not actually doing advanced
         # indexing, so __getitem__ will not return a copy.
         # Since no view_map is set, we need to copy the returned value
-        if not any(
-            isinstance(v.type, TensorType) and v.ndim > 0 for v in node.inputs[1:]
-        ):
+        has_tensor_indices = any(
+            isinstance(entry, Type) and not getattr(entry, 'broadcastable', (False,))[0]
+            for entry in self.idx_list
+        )
+        if not has_tensor_indices:
             rval = rval.copy()
         out[0] = rval

@@ -3005,7 +3129,7 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:

         This function checks if the advanced indexing is non-consecutive, in which
         case the advanced index dimensions are placed on the left of the
-        output array, regardless of their opriginal position.
+        output array, regardless of their original position.

         See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing

@@ -3020,11 +3144,27 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
         bool
             True if the advanced indexing is non-consecutive, False otherwise.

        """
-        _, *idxs = node.inputs
-        return _non_consecutive_adv_indexing(idxs)
+        # Reconstruct the full indices from idx_list and inputs to check consecutivity (newaxis handled by __getitem__)
+        op = node.op
+        tensor_inputs = node.inputs[1:]
+
+        full_indices = []
+        input_idx = 0
+
+        for entry in op.idx_list:
+            if isinstance(entry, slice):
+                full_indices.append(slice(None))  # Represent as basic slice
+            elif isinstance(entry, Type):
+                # This is a numerical index - get from inputs
+                if input_idx < len(tensor_inputs):
+                    full_indices.append(tensor_inputs[input_idx])
+                    input_idx += 1
+
+        return _non_consecutive_adv_indexing(full_indices)


-advanced_subtensor = AdvancedSubtensor()
+# Note: This is now a factory function since AdvancedSubtensor needs idx_list
+# The old global instance approach won't work anymore


 @_vectorize_node.register(AdvancedSubtensor)
@@ -3044,30 +3184,27 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
         # which would put the indexed results to the left of the batch dimensions!

         # TODO: Not all cases must be handled by Blockwise, but the logic is complex
-        # Blockwise doesn't accept None or Slices types so we raise informative error here
-        # TODO: Implement these internally, so Blockwise is always a safe fallback
-        if any(not isinstance(idx, TensorVariable) for idx in idxs):
-            raise NotImplementedError(
-                "Vectorized AdvancedSubtensor with batched indexes or non-consecutive advanced indexing "
-                "and slices or newaxis is currently not supported."
-            )
-        else:
-            return vectorize_node_fallback(op, node, batch_x, *batch_idxs)
+        # With the new interface, all inputs are tensors, so Blockwise can handle them
+        return vectorize_node_fallback(op, node, batch_x, *batch_idxs)

     # Otherwise we just need to add None slices for every new batch dim
     x_batch_ndim = batch_x.type.ndim - x.type.ndim
     empty_slices = (slice(None),) * x_batch_ndim
-    return op.make_node(batch_x, *empty_slices, *batch_idxs)
+    new_idx_list = empty_slices + op.idx_list
+    return AdvancedSubtensor(new_idx_list).make_node(batch_x, *batch_idxs)


 class AdvancedIncSubtensor(Op):
     """Increments a subtensor using advanced indexing."""

-    __props__ = ("inplace", "set_instead_of_inc", "ignore_duplicates")
+    __props__ = ("inplace", "set_instead_of_inc", "ignore_duplicates", "idx_list")

     def __init__(
-        self, inplace=False, set_instead_of_inc=False, ignore_duplicates=False
+        self, idx_list, inplace=False, set_instead_of_inc=False, ignore_duplicates=False
     ):
+        self.idx_list = tuple(map(index_vars_to_types, idx_list))
+        # Store expected number of tensor inputs for validation
+        self.expected_inputs_len = len(get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type)))
         self.set_instead_of_inc = set_instead_of_inc
         self.inplace = inplace
         if inplace:
@@ -3085,6 +3222,10 @@ def make_node(self, x, y, *inputs):
         x = as_tensor_variable(x)
         y = as_tensor_variable(y)

+        # Validate that we have the right number of tensor inputs for our idx_list
+        if len(inputs) != self.expected_inputs_len:
+            raise ValueError(f"Expected {self.expected_inputs_len} tensor inputs but got {len(inputs)}")
+
         new_inputs = []
         for inp in inputs:
             if isinstance(inp, list | tuple):
@@ -3097,9 +3238,43 @@
         )

     def perform(self, node, inputs, out_):
-        x, y, *indices = inputs
+        x, y, *tensor_inputs = inputs

-        check_advanced_indexing_dimensions(x, indices)
+        # Reconstruct the full tuple of indices from idx_list and inputs (newaxis handled by __getitem__)
+        full_indices = []
+        input_idx = 0
+
+        for entry in self.idx_list:
+            if isinstance(entry, slice):
+                # Reconstruct slice from idx_list and inputs
+                if entry.start is not None and isinstance(entry.start, Type):
+                    start_val = tensor_inputs[input_idx]
+                    input_idx += 1
+                else:
+                    start_val = entry.start
+
+                if entry.stop is not None and isinstance(entry.stop, Type):
+                    stop_val = tensor_inputs[input_idx]
+                    input_idx += 1
+                else:
+                    stop_val = entry.stop
+
+                if entry.step is not None and isinstance(entry.step, Type):
+                    step_val = tensor_inputs[input_idx]
+                    input_idx += 1
+                else:
+                    step_val = entry.step
+
+                full_indices.append(slice(start_val, stop_val, step_val))
+            elif isinstance(entry, Type):
+                # This is a numerical index - get from inputs
+                if input_idx < len(tensor_inputs):
+                    full_indices.append(tensor_inputs[input_idx])
+                    input_idx += 1
+                else:
+                    raise ValueError("Mismatch between idx_list and inputs")
+
+        check_advanced_indexing_dimensions(x, full_indices)

         (out,) = out_
         if not self.inplace:
@@ -3108,11 +3283,11 @@ def perform(self, node, inputs, out_):
             out[0] = x

         if self.set_instead_of_inc:
-            out[0][tuple(indices)] = y
+            out[0][tuple(full_indices)] = y
         elif self.ignore_duplicates:
-            out[0][tuple(indices)] += y
+            out[0][tuple(full_indices)] += y
         else:
-            np.add.at(out[0], tuple(indices), y)
+            np.add.at(out[0], tuple(full_indices), y)

     def infer_shape(self, fgraph, node, ishapes):
         return [ishapes[0]]
@@ -3142,10 +3317,12 @@ def grad(self, inpt, output_gradients):
             raise NotImplementedError("No support for complex grad yet")
         else:
             if self.set_instead_of_inc:
-                gx = advanced_set_subtensor(outgrad, y.zeros_like(), *idxs)
+                gx = AdvancedIncSubtensor(self.idx_list, set_instead_of_inc=True).make_node(
+                    outgrad, y.zeros_like(), *idxs
+                ).outputs[0]
             else:
                 gx = outgrad
-            gy = advanced_subtensor(outgrad, *idxs)
+            gy = AdvancedSubtensor(self.idx_list).make_node(outgrad, *idxs).outputs[0]
             # Make sure to sum gy over the dimensions of y that have been
             # added or broadcasted
             gy = _sum_grad_over_bcasted_dims(y, gy)
@@ -3165,7 +3342,7 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:

         This function checks if the advanced indexing is non-consecutive, in which
         case the advanced index dimensions are placed on the left of the
-        output array, regardless of their opriginal position.
+        output array, regardless of their original position.

         See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing

@@ -3180,16 +3357,127 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
         bool
             True if the advanced indexing is non-consecutive, False otherwise.
         """
-        _, _, *idxs = node.inputs
-        return _non_consecutive_adv_indexing(idxs)
+        # Reconstruct the full indices from idx_list and inputs to check consecutivity (newaxis handled by __getitem__)
+        op = node.op
+        tensor_inputs = node.inputs[2:]  # Skip x and y
+
+        full_indices = []
+        input_idx = 0
+
+        for entry in op.idx_list:
+            if isinstance(entry, slice):
+                full_indices.append(slice(None))  # Represent as basic slice
+            elif isinstance(entry, Type):
+                # This is a numerical index - get from inputs
+                if input_idx < len(tensor_inputs):
+                    full_indices.append(tensor_inputs[input_idx])
+                    input_idx += 1
+
+        return _non_consecutive_adv_indexing(full_indices)
+
+
+def advanced_subtensor(x, *args):
+    """Create an AdvancedSubtensor operation.
+
+    This function converts the arguments to work with the new AdvancedSubtensor
+    interface that separates slice structure from variable inputs.
+
+    Note: newaxis (None) should be handled by __getitem__ using dimshuffle
+    before calling this function.
+    """
+    # Convert args using as_index_variable (like original AdvancedSubtensor did)
+    processed_args = tuple(map(as_index_variable, args))
+
+    # Now create idx_list and extract inputs
+    idx_list = []
+    input_vars = []
+
+    for arg in processed_args:
+        if isinstance(arg.type, SliceType):
+            # Handle SliceType - extract components and structure
+            if isinstance(arg, Constant):
+                # Constant slice
+                idx_list.append(arg.data)
+            elif arg.owner and isinstance(arg.owner.op, MakeSlice):
+                # Variable slice - extract components
+                start, stop, step = arg.owner.inputs
+
+                # Convert components to types for idx_list
+                start_type = index_vars_to_types(start, False) if not isinstance(start.type, NoneTypeT) else None
+                stop_type = index_vars_to_types(stop, False) if not isinstance(stop.type, NoneTypeT) else None
+                step_type = index_vars_to_types(step, False) if not isinstance(step.type, NoneTypeT) else None
+
+                idx_list.append(slice(start_type, stop_type, step_type))
+
+                # Add variable components to inputs
+                if not isinstance(start.type, NoneTypeT):
+                    input_vars.append(start)
+                if not isinstance(stop.type, NoneTypeT):
+                    input_vars.append(stop)
+                if not isinstance(step.type, NoneTypeT):
+                    input_vars.append(step)
+            else:
+                # Other slice case
+                idx_list.append(slice(None))
+        else:
+            # Tensor index (should not be NoneType since newaxis handled in __getitem__)
+            idx_list.append(index_vars_to_types(arg))
+            input_vars.append(arg)
+
+    return AdvancedSubtensor(idx_list).make_node(x, *input_vars).outputs[0]
+
+
+def advanced_inc_subtensor(x, y, *args, **kwargs):
+    """Create an AdvancedIncSubtensor operation for incrementing.
+
+    Note: newaxis (None) should be handled by __getitem__ using dimshuffle
+    before calling this function.
+    """
+    # Convert args using as_index_variable (like original AdvancedIncSubtensor would)
+    processed_args = tuple(map(as_index_variable, args))
+
+    # Now create idx_list and extract inputs
+    idx_list = []
+    input_vars = []
+
+    for arg in processed_args:
+        if isinstance(arg.type, SliceType):
+            # Handle SliceType - extract components and structure
+            if isinstance(arg, Constant):
+                # Constant slice
+                idx_list.append(arg.data)
+            elif arg.owner and isinstance(arg.owner.op, MakeSlice):
+                # Variable slice - extract components
+                start, stop, step = arg.owner.inputs
+
+                # Convert components to types for idx_list
+                start_type = index_vars_to_types(start, False) if not isinstance(start.type, NoneTypeT) else None
+                stop_type = index_vars_to_types(stop, False) if not isinstance(stop.type, NoneTypeT) else None
+                step_type = index_vars_to_types(step, False) if not isinstance(step.type, NoneTypeT) else None
+
+                idx_list.append(slice(start_type, stop_type, step_type))
+
+                # Add variable components to inputs
+                if not isinstance(start.type, NoneTypeT):
+                    input_vars.append(start)
+                if not isinstance(stop.type, NoneTypeT):
+                    input_vars.append(stop)
+                if not isinstance(step.type, NoneTypeT):
+                    input_vars.append(step)
+            else:
+                # Other slice case
+                idx_list.append(slice(None))
+        else:
+            # Tensor index (should not be NoneType since newaxis handled in __getitem__)
+            idx_list.append(index_vars_to_types(arg))
+            input_vars.append(arg)
+
+    return AdvancedIncSubtensor(idx_list, **kwargs).make_node(x, y, *input_vars).outputs[0]


-advanced_inc_subtensor = AdvancedIncSubtensor()
-advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
-advanced_inc_subtensor_nodup = AdvancedIncSubtensor(ignore_duplicates=True)
-advanced_set_subtensor_nodup = AdvancedIncSubtensor(
-    set_instead_of_inc=True, ignore_duplicates=True
-)
+def advanced_set_subtensor(x, y, *args, **kwargs):
+    """Create an AdvancedIncSubtensor operation for setting."""
+    return advanced_inc_subtensor(x, y, *args, set_instead_of_inc=True, **kwargs)


 def take(a, indices, axis=None, mode="raise"):
diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py
index 474d08c49d..d8cabe737f 100644
--- a/pytensor/tensor/variable.py
+++ b/pytensor/tensor/variable.py
@@ -555,55 +555,55 @@ def is_empty_array(val):
                 else:
                     advanced = True

-        if advanced:
-            return pt.subtensor.advanced_subtensor(self, *args)
-        else:
-            if np.newaxis in args or NoneConst in args:
-                # `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
-                # broadcastable dimension at this location". Since PyTensor adds
-                # new broadcastable dimensions via the `DimShuffle` `Op`, the
-                # following code uses said `Op` to add one of the new axes and
-                # then uses recursion to apply any other indices and add any
-                # remaining new axes.
-
-                counter = 0
-                pattern = []
-                new_args = []
-                for arg in args:
-                    if arg is np.newaxis or arg is NoneConst:
-                        pattern.append("x")
-                        new_args.append(slice(None, None, None))
-                    else:
-                        pattern.append(counter)
-                        counter += 1
-                        new_args.append(arg)
-
-                pattern.extend(list(range(counter, self.ndim)))
-
-                view = self.dimshuffle(pattern)
-                full_slices = True
-                for arg in new_args:
-                    # We can't do arg == slice(None, None, None) as in
-                    # Python 2.7, this call __lt__ if we have a slice
-                    # with some symbolic variable.
-                    if not (
-                        isinstance(arg, slice)
-                        and (arg.start is None or arg.start is NoneConst)
-                        and (arg.stop is None or arg.stop is NoneConst)
-                        and (arg.step is None or arg.step is NoneConst)
-                    ):
-                        full_slices = False
-                if full_slices:
-                    return view
+        # Handle newaxis (None) for both basic and advanced indexing
+        if np.newaxis in args or NoneConst in args:
+            # `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
+            # broadcastable dimension at this location". Since PyTensor adds
+            # new broadcastable dimensions via the `DimShuffle` `Op`, the
+            # following code uses said `Op` to add one of the new axes and
+            # then uses recursion to apply any other indices and add any
+            # remaining new axes.
+
+            counter = 0
+            pattern = []
+            new_args = []
+            for arg in args:
+                if arg is np.newaxis or arg is NoneConst:
+                    pattern.append("x")
+                    new_args.append(slice(None, None, None))
                 else:
-                    return view.__getitem__(tuple(new_args))
+                    pattern.append(counter)
+                    counter += 1
+                    new_args.append(arg)
+
+            pattern.extend(list(range(counter, self.ndim)))
+
+            view = self.dimshuffle(pattern)
+            full_slices = True
+            for arg in new_args:
+                # We can't do arg == slice(None, None, None) as in
+                # Python 2.7, this call __lt__ if we have a slice
+                # with some symbolic variable.
+                if not (
+                    isinstance(arg, slice)
+                    and (arg.start is None or arg.start is NoneConst)
+                    and (arg.stop is None or arg.stop is NoneConst)
+                    and (arg.step is None or arg.step is NoneConst)
+                ):
+                    full_slices = False
+            if full_slices:
+                return view
             else:
-                return pt.subtensor.Subtensor(args)(
-                    self,
-                    *pt.subtensor.get_slice_elements(
-                        args, lambda entry: isinstance(entry, Variable)
-                    ),
-                )
+                return view.__getitem__(tuple(new_args))
+        elif advanced:
+            return pt.subtensor.advanced_subtensor(self, *args)
+        else:
+            return pt.subtensor.Subtensor(args)(
+                self,
+                *pt.subtensor.get_slice_elements(
+                    args, lambda entry: isinstance(entry, Variable)
+                ),
+            )

     def __setitem__(self, key, value):
         raise TypeError(
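For orientation, here is a minimal sketch of how the refactored Ops in this diff are meant to be driven, assuming the constructor signatures shown above (`AdvancedSubtensor(idx_list)` and `AdvancedIncSubtensor(idx_list, ...)`); the spelling `rows.type` for a tensor entry in `idx_list` is an illustrative assumption based on the new docstring, not a verified final API.

```python
# Illustrative sketch only (assumed semantics of the refactored Ops in this diff).
import pytensor.tensor as pt
from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedIncSubtensor

x = pt.matrix("x")
rows = pt.lvector("rows")  # advanced integer index for the first axis
y = pt.vector("y")

# Per the new docstring, idx_list stores slices as-is and replaces numerical
# indices by their types; only the tensor indices travel as graph inputs.
idx_list = [rows.type, slice(None)]  # assumed spelling of a tensor-index entry

sub_op = AdvancedSubtensor(idx_list)
out = sub_op.make_node(x, rows).outputs[0]  # roughly x[rows, :]

set_op = AdvancedIncSubtensor(idx_list, set_instead_of_inc=True)
out_set = set_op.make_node(x, y, rows).outputs[0]  # roughly set x[rows, :] to y
```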