     ScalarFromTensor,
     TensorFromScalar,
     alloc,
+    arange,
     cast,
     concatenate,
     expand_dims,
     switch,
 )
 from pytensor.tensor.basic import constant as tensor_constant
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, _squeeze_left
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.math import (
     add,
     and_,
 )
 from pytensor.tensor.shape import (
     shape_padleft,
+    shape_padright,
     shape_tuple,
 )
 from pytensor.tensor.sharedvar import TensorSharedVariable
@@ -1580,6 +1583,9 @@ def local_blockwise_of_subtensor(fgraph, node):
     """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
 
     Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
+
+    TODO: Handle batched indices like we do with blockwise of inc_subtensor
+    TODO: Extend to AdvancedSubtensor
     """
     if not isinstance(node.op.core_op, Subtensor):
         return
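
For intuition, a minimal NumPy sketch of the equivalence this rewrite exploits (shapes and values here are assumed purely for illustration):

import numpy as np

x = np.arange(24).reshape(4, 6)  # one batch dimension, core vectors of length 6
a, b = 1, 4                      # scalar slice bounds with no batch dimensions
looped = np.stack([x[i, a:b] for i in range(x.shape[0])])  # what Blockwise computes
assert np.array_equal(looped, x[:, a:b])  # the rewrite just prepends a full slice
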
@@ -1600,64 +1606,151 @@ def local_blockwise_of_subtensor(fgraph, node):
 @register_stabilize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Blockwise])
-def local_blockwise_advanced_inc_subtensor(fgraph, node):
-    """Rewrite blockwise advanced inc_subtensor whithout batched indexes as an inc_subtensor with prepended empty slices."""
-    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
-        return None
+def local_blockwise_inc_subtensor(fgraph, node):
+    """Rewrite Blockwise of (Advanced)IncSubtensor.
 
-    x, y, *idxs = node.inputs
+    Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
+    is that we often have batch dimensions from allocs of shapes/reshapes that can be removed by rewrites,
 
-    # It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case
-    if any(
-        (
-            isinstance(idx, SliceType | NoneTypeT)
-            or (idx.type.dtype == "bool" and idx.type.ndim > 0)
-        )
-        for idx in idxs
-    ):
+    such as x[:vectorized(w.shape[0])].set(y), which will later be rewritten as x[:w.shape[1]].set(y)
+    and can then be safely rewritten without Blockwise.
+    """
+    core_op = node.op.core_op
+    if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
         return None
 
-    op: Blockwise = node.op  # type: ignore
-    batch_ndim = op.batch_ndim(node)
-
-    new_idxs = []
-    for idx in idxs:
-        if all(idx.type.broadcastable[:batch_ndim]):
-            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
-        else:
-            # Rewrite does not apply
+    x, y, *idxs = node.inputs
+    [out] = node.outputs
+    if isinstance(core_op, AdvancedIncSubtensor):
+        if any(
+            (
+                # Blockwise requires all inputs to be tensors, so it is not possible
+                # to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case.
+                # If this is ever supported, we need to pay attention to the special behavior of numpy
+                # when advanced indices are separated by basic indices
+                isinstance(idx, SliceType | NoneTypeT)
+                # Also get out if we have boolean indices, as they cross dimension boundaries
+                # / can't be safely broadcast depending on their runtime content
+                or idx.type.dtype == "bool"
+            )
+            for idx in idxs
+        ):
             return None
 
-    x_batch_bcast = x.type.broadcastable[:batch_ndim]
-    y_batch_bcast = y.type.broadcastable[:batch_ndim]
-    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)):
-        # Need to broadcast batch x dims
-        batch_shape = tuple(
-            x_dim if (not xb or yb) else y_dim
-            for xb, x_dim, yb, y_dim in zip(
-                x_batch_bcast,
+    batch_ndim = node.op.batch_ndim(node)
+    idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
+    max_idx_core_ndim = max(idxs_core_ndim, default=0)
+
+    # Step 1. Broadcast buffer to batch_shape
+    if x.type.broadcastable != out.type.broadcastable:
+        batch_shape = [1] * batch_ndim
+        for inp in node.inputs:
+            for i, (broadcastable, batch_dim) in enumerate(
+                zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
+            ):
+                if broadcastable:
+                    # This dimension is broadcastable, it doesn't provide shape information
+                    continue
+                if batch_shape[i] != 1:
+                    # We already found a source of shape for this batch dimension
+                    continue
+                batch_shape[i] = batch_dim
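+        # e.g. if x has batch dims (1, b) and some index has batch dims (a, 1),
+        # batch_shape becomes [a, b], which the buffer is broadcast to below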
+        x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
+        assert x.type.broadcastable == out.type.broadcastable
+
+    # Step 2. Massage indices so they respect blockwise semantics
+    if isinstance(core_op, IncSubtensor):
+        # For basic IncSubtensor there are two cases:
+        # 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
+        # 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
+        #   in case we can end up with a basic IncSubtensor again
+        core_idxs = []
+        counter = 0
+        for idx in core_op.idx_list:
+            if isinstance(idx, slice):
+                # Squeeze away dummy dimensions so we can convert to slice
+                new_entries = [None, None, None]
+                for i, entry in enumerate((idx.start, idx.stop, idx.step)):
+                    if entry is None:
+                        continue
+                    else:
+                        new_entries[i] = new_entry = idxs[counter].squeeze()
+                        counter += 1
+                        if new_entry.ndim > 0:
+                            # If the slice entry still has dimensions after the squeeze, we can't convert it to a slice.
+                            # We could try to convert to equivalent integer indices, but nothing guarantees
+                            # that the slice is "square".
+                            return None
+                core_idxs.append(slice(*new_entries))
+            else:
+                core_idxs.append(_squeeze_left(idxs[counter]))
+                counter += 1
+    else:
+        # For AdvancedIncSubtensor we have integer tensor indices.
+        # We need to expand batch indices on the right, so they don't interact with core index dimensions.
+        # We still squeeze on the left in case that allows us to use simpler indices
+        core_idxs = [
+            _squeeze_left(
+                shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
+                stop_at_dim=batch_ndim,
+            )
+            for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
+        ]
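+        # e.g. an index with batch shape (b,) and core ndim 0, used next to a
+        # core index of ndim 2, becomes shape (b, 1, 1): the trailing unit dims
+        # keep it broadcasting against the core index dims instead of mixing with them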
+
+    # Step 3. Create new indices for the new batch dimension of x
+    if not all(
+        all(idx.type.broadcastable[:batch_ndim])
+        for idx in idxs
+        if not isinstance(idx, slice)
+    ):
+        # If the indices have batch dimensions, they will interact with the new batch dimensions of x.
+        # We build vectorized indexing with new arange indices that do not interact with the core indices
+        # or with each other (i.e., they broadcast)
+
+        # Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
+        # we don't want to create a mix of slice(None) and arange() indices for the new batch dimension,
+        # even if not all batch dimensions have corresponding batch indices.
+        batch_slices = [
+            shape_padright(arange(x_batch_shape, dtype="int64"), n)
+            for (x_batch_shape, n) in zip(
                 tuple(x.shape)[:batch_ndim],
-                y_batch_bcast,
-                tuple(y.shape)[:batch_ndim],
-                strict=True,
+                reversed(range(max_idx_core_ndim, max_idx_core_ndim + batch_ndim)),
             )
-        )
-        core_shape = tuple(x.shape)[batch_ndim:]
-        x = alloc(x, *batch_shape, *core_shape)
-
-    new_idxs = [slice(None)] * batch_ndim + new_idxs
-    x_view = x[tuple(new_idxs)]
-
-    # We need to introduce any implicit expand_dims on core dimension of y
-    y_core_ndim = y.type.ndim - batch_ndim
-    if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0:
-        missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
-        y = expand_dims(y, missing_axes)
-
-    symbolic_idxs = x_view.owner.inputs[1:]
-    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
-    copy_stack_trace(node.outputs, new_out)
-    return new_out
+        ]
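+        # e.g. with batch_ndim=2 and max_idx_core_ndim=1 the aranges get shapes
+        # (b0, 1, 1) and (b1, 1): they broadcast with each other and with the
+        # core indices, selecting exactly one batch entry per update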
+    else:
+        # In the case we don't have batch indices,
+        # we can use slice(None) to broadcast the core indices to each new batch dimension of x / y
+        batch_slices = [slice(None)] * batch_ndim
+
+    new_idxs = (*batch_slices, *core_idxs)
+    x_view = x[new_idxs]
+
+    # Step 4. Introduce any implicit expand_dims on core dimension of y
+    missing_y_core_ndim = x_view.type.ndim - y.type.ndim
+    implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
+    y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
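+    # leading broadcastable batch dims of y are redundant (they broadcast against
+    # x_view anyway), so they are squeezed away to keep the replacement graph lean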
+
+    if isinstance(core_op, IncSubtensor):
+        # Check if we can still use a basic IncSubtensor
+        if isinstance(x_view.owner.op, Subtensor):
+            new_props = core_op._props_dict()
+            new_props["idx_list"] = x_view.owner.op.idx_list
+            new_core_op = type(core_op)(**new_props)
+            symbolic_idxs = x_view.owner.inputs[1:]
+            new_out = new_core_op(x, y, *symbolic_idxs)
+        else:
+            # We need to use AdvancedSet/IncSubtensor
+            if core_op.set_instead_of_inc:
+                new_out = x[new_idxs].set(y)
+            else:
+                new_out = x[new_idxs].inc(y)
+    else:
+        # AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
+        symbolic_idxs = x_view.owner.inputs[1:]
+        new_out = core_op(x, y, *symbolic_idxs)
+
+    copy_stack_trace(out, new_out)
+    return [new_out]
 
 
 @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
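
For intuition, a minimal NumPy sketch of the per-batch loop this rewrite replaces in the basic IncSubtensor case (shapes and values here are assumed purely for illustration):

import numpy as np

x = np.zeros((4, 5))    # buffer with one batch dimension
y = np.ones((4, 2))     # batched update values

looped = x.copy()
for i in range(4):      # Blockwise semantics: one set_subtensor per batch entry
    looped[i, 1:3] = y[i]

vectorized = x.copy()
vectorized[:, 1:3] = y  # rewritten form: a full slice prepended for the batch dim

assert np.array_equal(looped, vectorized)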