Description
Ops are supposed to be immutable/hashable.
The hash of the Scan Op is defined based on the graph with which it was originally created.
However, the equality of Scan depends on the state of the internal FunctionGraph, which gets modified during compilation. This can lead to inconsistent Scan Ops that compare equal but hash differently, as in the example below, where both inner graphs simplify to the same graph during compilation.
```python
import pytensor
import pytensor.tensor as pt
from pytensor.scan.op import Scan
from pytensor.compile.mode import get_default_mode

# Define Scan Ops with distinct initial hashes
x0 = pt.scalar("x0")
xs, _ = pytensor.scan(lambda x: x + 0, outputs_info=[x0], n_steps=5)
ys, _ = pytensor.scan(lambda x: x * 1, outputs_info=[x0], n_steps=5)

# During compilation the hashes remain the same, but the Op fgraphs change
fn = pytensor.function([x0], [xs, ys], mode=get_default_mode().excluding("scan"))
scan_op1, scan_op2 = [node.op for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)]

assert scan_op1 == scan_op2  # As it should
assert hash(scan_op1) == hash(scan_op2)  # Fails
```
This breaks the Python contract that objects which compare equal must have equal hashes (or not be hashable at all). This doesn't seem to be causing us many problems so far, but it's a potential footgun. We should reassess whether the desire to not clone the fgraph, while still allowing it to be optimized, makes sense.
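For instance, hash-based containers bucket by hash before checking equality, so equal Ops with different hashes can evade deduplication. A minimal sketch, reusing `scan_op1` and `scan_op2` from the reproduction above:

```python
# A set first buckets elements by hash and only then compares for equality,
# so two equal objects with different hashes can both end up in the set.
ops = {scan_op1, scan_op2}
print(len(ops))  # may print 2, even though scan_op1 == scan_op2
```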
I suspect optimizing the inner graph should be an explicit rewrite that clones the FunctionGraph and creates a new Op. The FunctionGraph inside the Scan Op should be treated as frozen (we may even want to add a frozen FunctionGraph class to enforce this).
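A minimal sketch of what such a class could look like. This is an illustration only: it guards just two mutating entry points (`replace` and `change_node_input`), whereas a real implementation would need to cover every mutator on FunctionGraph:

```python
from pytensor.graph.fg import FunctionGraph


class FrozenFunctionGraph(FunctionGraph):
    """Hypothetical FunctionGraph that rejects mutation after construction."""

    def __init__(self, *args, **kwargs):
        self._frozen = False  # allow mutation while the graph is being built
        super().__init__(*args, **kwargs)
        self._frozen = True

    def _check_frozen(self):
        if self._frozen:
            raise RuntimeError("This FunctionGraph is frozen; clone it before rewriting.")

    def replace(self, *args, **kwargs):
        self._check_frozen()
        return super().replace(*args, **kwargs)

    def change_node_input(self, *args, **kwargs):
        self._check_frozen()
        return super().change_node_input(*args, **kwargs)
```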
Note that all Scan rewrites already create new Ops and don't mutate the original FunctionGraph; this is merely a compilation-time concern.
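For reference, the safe pattern during compilation would be to rewrite a clone rather than the Op's own graph. A minimal sketch on a standalone FunctionGraph (not an actual Scan rewrite):

```python
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph

x = pt.scalar("x")
fg = FunctionGraph([x], [x + 0])

# Rewrites would run on the clone; the original graph, and any Op hash
# derived from it, stays untouched.
fg_clone = fg.clone()
assert fg_clone is not fg
```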