Add rewrite to fuse nested BlockDiag Ops #1593

@jessegrabowski

Description

This comes up in statespace models. BlockDiag is variadic, so nested calls like the one below can be flattened into a single Op during canonicalization:

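import pytensor
import pytensor.tensor as pt

a, b, c = pt.matrices('a', 'b', 'c')
# Nested call: the inner block_diag result is one input of the outer one
x = pt.linalg.block_diag(pt.linalg.block_diag(a, b), c)
fn = pytensor.function([a, b, c], x)
# Print the compiled graph
fn.dprint()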

Current output:

BlockDiagonal{n_inputs=2} [id A] 1
 ├─ BlockDiagonal{n_inputs=2} [id B] 0
 │  ├─ a [id C]
 │  └─ b [id D]
 └─ c [id E]

Desired output:

BlockDiagonal{n_inputs=3} [id A] 0
 ├─ a [id B]
 ├─ b [id C]
 └─ c [id D]
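
The two graphs should compute the same thing because block-diagonal stacking is associative: nesting changes how the blocks are grouped, not where they end up. A minimal sketch that checks this on plain nested lists follows. Note that `block_diag` here is a hypothetical pure-Python helper written only for illustration, not PyTensor's implementation, and it assumes non-empty 2-D inputs.

```python
def block_diag(*mats):
    """Stack 2-D matrices (lists of rows) along the diagonal, zero-filling the rest."""
    total_cols = sum(len(m[0]) for m in mats)
    out = []
    col_offset = 0
    for m in mats:
        n_cols = len(m[0])
        for row in m:
            # Zeros to the left and to the right of this block's columns.
            left = [0] * col_offset
            right = [0] * (total_cols - col_offset - n_cols)
            out.append(left + list(row) + right)
        col_offset += n_cols
    return out


a = [[1, 2]]            # shape (1, 2)
b = [[3], [4]]          # shape (2, 1)
c = [[5, 6], [7, 8]]    # shape (2, 2)

nested = block_diag(block_diag(a, b), c)  # mirrors the current graph
flat = block_diag(a, b, c)                # mirrors the desired graph
print(nested == flat)
```

The blocks are deliberately non-square so that row and column offsets differ, which is the case most likely to expose an indexing mistake.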

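The only wrinkle I can see is that there might be ExpandDims Ops sandwiched between each "level" of BlockDiagonal. We should either push these inside (onto the inputs of the inner BlockDiagonal) or pull them out past the outer one, so that the two BlockDiagonal Ops end up adjacent and can be fused.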
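
To make the idea concrete, here is a rough sketch of the flattening logic on a toy expression tree. Everything in it is hypothetical: `Node` and `fuse_block_diag` are illustration-only names and do not use PyTensor's rewrite machinery. It takes the "push inside" option, and it assumes the sandwiched ExpandDims only adds a leading batch axis (`axis == 0`) and that BlockDiagonal broadcasts over batch dimensions; any other ExpandDims is left in place and blocks the fusion.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Node:
    """Toy graph node: an op name, its inputs, and an axis (ExpandDims only)."""
    op: str
    inputs: tuple = ()
    axis: Optional[int] = None


def fuse_block_diag(node):
    """Recursively flatten nested BlockDiag nodes into one variadic node."""
    if not node.inputs:
        return node  # leaf variable
    inputs = tuple(fuse_block_diag(i) for i in node.inputs)
    if node.op != "BlockDiag":
        return Node(node.op, inputs, node.axis)
    flat = []
    for inp in inputs:
        if inp.op == "BlockDiag":
            # Directly nested: splice the inner blocks into the outer call.
            flat.extend(inp.inputs)
        elif (inp.op == "ExpandDims" and inp.axis == 0
              and inp.inputs[0].op == "BlockDiag"):
            # Push inside: add the leading axis to each inner block instead.
            flat.extend(Node("ExpandDims", (blk,), 0)
                        for blk in inp.inputs[0].inputs)
        else:
            flat.append(inp)
    return Node("BlockDiag", tuple(flat))


a, b, c = Node("a"), Node("b"), Node("c")

nested = Node("BlockDiag", (Node("BlockDiag", (a, b)), c))
fused = fuse_block_diag(nested)

wrapped = Node("BlockDiag",
               (Node("ExpandDims", (Node("BlockDiag", (a, b)),), 0), c))
fused_wrapped = fuse_block_diag(wrapped)

print(len(fused.inputs), [n.op for n in fused_wrapped.inputs])
```

Because inputs are fused before their parent, arbitrarily deep nesting collapses in a single pass. A real rewrite would also need to check whether the inner BlockDiagonal has other clients before splicing it away, which this sketch ignores.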
