Skip to content

Commit 3e9665c

Browse files
Hash-cons Apply, Constant and change node input replacement semantics
1 parent 3c665a5 commit 3e9665c

File tree

19 files changed

+1034
-765
lines changed

19 files changed

+1034
-765
lines changed

aesara/compile/debugmode.py

Lines changed: 57 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
from itertools import chain
1515
from itertools import product as itertools_product
1616
from logging import Logger
17-
from typing import Optional
17+
from typing import TYPE_CHECKING, Optional, Union
1818
from warnings import warn
1919

2020
import numpy as np
21+
from typing_extensions import Literal
2122

2223
import aesara
2324
from aesara.compile.function.types import (
@@ -42,7 +43,9 @@
4243
from aesara.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function
4344

4445

45-
__docformat__ = "restructuredtext en"
46+
if TYPE_CHECKING:
47+
from aesara.graph.basic import Apply
48+
4649
_logger: Logger = logging.getLogger("aesara.compile.debugmode")
4750
_logger.addFilter(NoDuplicateOptWarningFilter())
4851

@@ -1109,43 +1112,32 @@ class _FunctionGraphEvent:
11091112
11101113
"""
11111114

1112-
kind = ""
1113-
"""
1114-
One of 'import', 'change', 'prune'.
1115-
1116-
"""
1117-
1118-
node = None
1119-
"""
1120-
Either 'output' or an Apply instance.
1121-
1122-
"""
1123-
1124-
op = None
1125-
"""Either 'output' or an Op instance"""
1115+
kind: Literal["import", "change", "prune"]
1116+
old_node: Optional[Union[Literal["output"], "Apply"]]
1117+
new_node: Optional[Union[Literal["output"], "Apply"]]
1118+
op: Optional[Union[Literal["output"], Op]]
1119+
idx: Optional[int]
1120+
reason: Optional[str]
11261121

1127-
idx = None
1128-
"""
1129-
Change events involve an position index of the input variable.
1130-
1131-
"""
1132-
1133-
reason = None
1134-
"""
1135-
Change events sometimes have a reason.
1136-
1137-
"""
1138-
1139-
def __init__(self, kind, node, idx=None, reason=None):
1122+
def __init__(
1123+
self,
1124+
kind: Literal["import", "change", "prune"],
1125+
old_node: Union[Literal["output"], "Apply"],
1126+
new_node: Union[Literal["output"], "Apply"] = None,
1127+
idx: Optional[int] = None,
1128+
reason: Optional[str] = None,
1129+
):
11401130
self.kind = kind
1141-
if node == "output":
1142-
self.node = "output"
1131+
if old_node == "output":
1132+
self.old_node = "output"
1133+
self.new_node = "output"
11431134
self.op = "output"
11441135
else:
1145-
self.node = node
1146-
self.op = node.op
1136+
self.old_node = old_node
1137+
self.new_node = new_node
1138+
self.op = old_node.op
11471139
self.idx = idx
1148-
self.reason = str(reason)
1140+
self.reason = str(reason) if reason else None
11491141

11501142
def __str__(self):
11511143
if self.kind == "change":
@@ -1219,21 +1211,21 @@ def on_attach(self, fgraph):
12191211
self.replaced_by = {}
12201212
self.event_list = []
12211213
for node in fgraph.toposort():
1222-
self.on_import(fgraph, node, "on_attach")
1214+
self.on_import(fgraph, node, reason="on_attach")
12231215

12241216
def on_detach(self, fgraph):
12251217
assert fgraph is self.fgraph
12261218
self.fgraph = None
12271219

12281220
def on_prune(self, fgraph, node, reason):
1229-
self.event_list.append(_FunctionGraphEvent("prune", node, reason=str(reason)))
1221+
self.event_list.append(_FunctionGraphEvent("prune", node, reason=reason))
12301222
assert node in self.active_nodes
12311223
assert node not in self.inactive_nodes
12321224
self.active_nodes.remove(node)
12331225
self.inactive_nodes.add(node)
12341226

12351227
def on_import(self, fgraph, node, reason):
1236-
self.event_list.append(_FunctionGraphEvent("import", node, reason=str(reason)))
1228+
self.event_list.append(_FunctionGraphEvent("import", node, reason=reason))
12371229

12381230
assert node not in self.active_nodes
12391231
self.active_nodes.add(node)
@@ -1253,31 +1245,36 @@ def on_import(self, fgraph, node, reason):
12531245
self.reasons.setdefault(r, [])
12541246
self.replaced_by.setdefault(r, [])
12551247

1256-
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
1248+
def on_change_input(
1249+
self, fgraph, old_node, new_node, i, old_var, new_var, reason=None
1250+
):
12571251
reason = str(reason)
12581252
self.event_list.append(
1259-
_FunctionGraphEvent("change", node, reason=reason, idx=i)
1253+
_FunctionGraphEvent("change", old_node, new_node, idx=i, reason=reason)
12601254
)
12611255

1262-
self.reasons.setdefault(new_r, [])
1263-
self.replaced_by.setdefault(new_r, [])
1256+
self.on_import(fgraph, new_node, reason=reason)
1257+
self.on_prune(fgraph, old_node, reason=reason)
1258+
1259+
self.reasons.setdefault(new_var, [])
1260+
self.replaced_by.setdefault(new_var, [])
12641261

12651262
append_reason = True
1266-
for tup in self.reasons[new_r]:
1267-
if tup[0] == reason and tup[1] is r:
1263+
for tup in self.reasons[new_var]:
1264+
if tup[0] == reason and tup[1] is old_var:
12681265
append_reason = False
12691266

12701267
if append_reason:
12711268
# N.B. compute the debugprint now, because future
12721269
# optimizations will change the graph
12731270
done = dict()
12741271
used_ids = dict()
1275-
self.reasons[new_r].append(
1272+
self.reasons[new_var].append(
12761273
(
12771274
reason,
1278-
r,
1275+
old_var,
12791276
_debugprint(
1280-
r,
1277+
old_var,
12811278
prefix=" ",
12821279
depth=6,
12831280
file=StringIO(),
@@ -1286,7 +1283,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
12861283
used_ids=used_ids,
12871284
).getvalue(),
12881285
_debugprint(
1289-
new_r,
1286+
new_var,
12901287
prefix=" ",
12911288
depth=6,
12921289
file=StringIO(),
@@ -1296,22 +1293,22 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
12961293
).getvalue(),
12971294
)
12981295
)
1299-
self.replaced_by[r].append((reason, new_r))
1296+
self.replaced_by[old_var].append((reason, new_var))
13001297

1301-
if r in self.equiv:
1302-
r_set = self.equiv[r]
1298+
if old_var in self.equiv:
1299+
r_set = self.equiv[old_var]
13031300
else:
1304-
r_set = self.equiv.setdefault(r, {r})
1305-
self.all_variables_ever.append(r)
1301+
r_set = self.equiv.setdefault(old_var, {old_var})
1302+
self.all_variables_ever.append(old_var)
13061303

1307-
if new_r in self.equiv:
1308-
new_r_set = self.equiv[new_r]
1304+
if new_var in self.equiv:
1305+
new_r_set = self.equiv[new_var]
13091306
else:
1310-
new_r_set = self.equiv.setdefault(new_r, {new_r})
1311-
self.all_variables_ever.append(new_r)
1307+
new_r_set = self.equiv.setdefault(new_var, {new_var})
1308+
self.all_variables_ever.append(new_var)
13121309

1313-
assert new_r in new_r_set
1314-
assert r in r_set
1310+
assert new_var in new_r_set
1311+
assert old_var in r_set
13151312

13161313
# update one equivalence set to contain the other
13171314
# transfer all the elements of the old one to the new one
@@ -1320,8 +1317,8 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
13201317
self.equiv[like_new_r] = r_set
13211318
assert like_new_r in r_set
13221319

1323-
assert self.equiv[r] is r_set
1324-
assert self.equiv[new_r] is r_set
1320+
assert self.equiv[old_var] is r_set
1321+
assert self.equiv[new_var] is r_set
13251322

13261323
def printstuff(self):
13271324
for key in self.equiv:

0 commit comments

Comments
 (0)