Open
Labels: beginner friendly, graph rewriting, help wanted, linalg
Description
This comes up in statespace models. `BlockDiagonal` is variadic, so nested graphs like this one can be canonicalized into a single node:
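```python
import pytensor
import pytensor.tensor as pt

a, b, c = pt.matrices("a", "b", "c")
# Nested call: the outer block_diag receives the inner block_diag's output as one of its inputs
x = pt.linalg.block_diag(pt.linalg.block_diag(a, b), c)
fn = pytensor.function([a, b, c], x)
fn.dprint()
```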
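A quick way to convince yourself the canonicalization is value-preserving: a block diagonal of a block diagonal is the same matrix as one flat block diagonal of all the leaves, in the same order. The sketch below checks this with NumPy; `block_diag_np` is a hypothetical helper written for illustration, not a PyTensor or NumPy function.

```python
import numpy as np


def block_diag_np(*mats):
    """Hypothetical helper: place 2-D arrays along the diagonal of a zero matrix, in order."""
    rows = sum(m.shape[0] for m in mats)
    cols = sum(m.shape[1] for m in mats)
    out = np.zeros((rows, cols), dtype=np.result_type(*mats))
    row = col = 0
    for m in mats:
        out[row:row + m.shape[0], col:col + m.shape[1]] = m
        row += m.shape[0]
        col += m.shape[1]
    return out


rng = np.random.default_rng(0)
a = rng.normal(size=(2, 3))
b = rng.normal(size=(1, 1))
c = rng.normal(size=(3, 2))

nested = block_diag_np(block_diag_np(a, b), c)  # mirrors block_diag(block_diag(a, b), c)
flat = block_diag_np(a, b, c)                   # mirrors the desired 3-input node
print(nested.shape, np.array_equal(nested, flat))  # (6, 6) True
```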
Current output:
```
BlockDiagonal{n_inputs=2} [id A] 1
 ├─ BlockDiagonal{n_inputs=2} [id B] 0
 │  ├─ a [id C]
 │  └─ b [id D]
 └─ c [id E]
```
Desired output:
```
BlockDiagonal{n_inputs=3} [id A] 0
 ├─ a [id B]
 ├─ b [id C]
 └─ c [id D]
```
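The rewrite itself amounts to splicing the inputs of any inner block-diagonal node into its parent, left to right. Below is a minimal pure-Python sketch of that flattening on a toy expression tree; `BlockDiag` and `flatten_block_diag` are hypothetical stand-ins, not PyTensor's `BlockDiagonal` Op or its rewriting API, and a real implementation would operate on PyTensor graph nodes instead.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BlockDiag:
    """Toy stand-in for a variadic block-diagonal node (hypothetical, not PyTensor's class)."""
    inputs: tuple

    @property
    def n_inputs(self):
        return len(self.inputs)


def flatten_block_diag(node):
    """Recursively splice nested BlockDiag inputs into their parent, preserving input order."""
    if not isinstance(node, BlockDiag):
        return node  # leaf, e.g. a matrix name
    flat = []
    for inp in node.inputs:
        inp = flatten_block_diag(inp)
        if isinstance(inp, BlockDiag):
            flat.extend(inp.inputs)  # splice the inner node's inputs in place
        else:
            flat.append(inp)
    return BlockDiag(tuple(flat))


graph = BlockDiag((BlockDiag(("a", "b")), "c"))  # block_diag(block_diag(a, b), c)
merged = flatten_block_diag(graph)
print(merged.n_inputs, merged.inputs)  # 3 ('a', 'b', 'c')
```

Input order matters here: block-diagonal concatenation is not commutative, so the inner inputs are spliced in exactly where the inner node sat.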
The only tricky wrinkle I can see is that there might be `ExpandDims` Ops sandwiched between each "level" of `BlockDiagonal`. We should either push these inside or pull them out so that the levels can be merged.
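For the `ExpandDims` wrinkle, pulling the expansion out is value-preserving under one assumption: the expansion only adds a leading length-1 axis, and the outer block diagonal broadcasts leading (batch) axes. The NumPy sketch below checks that assumption on a small example; `block_diag_batched` is a hypothetical batched helper, not a PyTensor function, and it does not model exactly how or where PyTensor inserts these Ops.

```python
import numpy as np


def block_diag_batched(*mats):
    """Hypothetical helper: block diagonal over the last two axes, broadcasting leading (batch) axes."""
    batch = np.broadcast_shapes(*(m.shape[:-2] for m in mats))
    rows = sum(m.shape[-2] for m in mats)
    cols = sum(m.shape[-1] for m in mats)
    out = np.zeros(batch + (rows, cols), dtype=np.result_type(*mats))
    row = col = 0
    for m in mats:
        out[..., row:row + m.shape[-2], col:col + m.shape[-1]] = m
        row += m.shape[-2]
        col += m.shape[-1]
    return out


rng = np.random.default_rng(1)
a, b, c = (rng.normal(size=s) for s in [(2, 2), (1, 3), (2, 1)])

# "Sandwiched" form: an expand-dims adds a leading length-1 axis between the two levels
inner = np.expand_dims(block_diag_batched(a, b), 0)  # shape (1, 3, 5)
sandwiched = block_diag_batched(inner, c)            # c broadcasts over the batch axis

# "Pulled out" form: one flat block diagonal, followed by a single expand-dims
pulled_out = np.expand_dims(block_diag_batched(a, b, c), 0)

print(sandwiched.shape, np.array_equal(sandwiched, pulled_out))  # (1, 5, 6) True
```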