Skip to content

Commit f4a9e86

Browse files
Hash-cons Apply, Constant and change node input replacement semantics
1 parent 8fb9d9b commit f4a9e86

File tree

18 files changed

+742
-679
lines changed

18 files changed

+742
-679
lines changed

aesara/compile/debugmode.py

Lines changed: 57 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
from itertools import chain
1515
from itertools import product as itertools_product
1616
from logging import Logger
17-
from typing import Optional
17+
from typing import TYPE_CHECKING, Optional, Union
1818
from warnings import warn
1919

2020
import numpy as np
21+
from typing_extensions import Literal
2122

2223
import aesara
2324
from aesara.compile.function.types import (
@@ -42,7 +43,9 @@
4243
from aesara.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function
4344

4445

45-
__docformat__ = "restructuredtext en"
46+
if TYPE_CHECKING:
47+
from aesara.graph.basic import Apply
48+
4649
_logger: Logger = logging.getLogger("aesara.compile.debugmode")
4750
_logger.addFilter(NoDuplicateOptWarningFilter())
4851

@@ -1108,43 +1111,32 @@ class _FunctionGraphEvent:
11081111
11091112
"""
11101113

1111-
kind = ""
1112-
"""
1113-
One of 'import', 'change', 'prune'.
1114-
1115-
"""
1116-
1117-
node = None
1118-
"""
1119-
Either 'output' or an Apply instance.
1120-
1121-
"""
1122-
1123-
op = None
1124-
"""Either 'output' or an Op instance"""
1114+
kind: Literal["import", "change", "prune"]
1115+
old_node: Optional[Union[Literal["output"], "Apply"]]
1116+
new_node: Optional[Union[Literal["output"], "Apply"]]
1117+
op: Optional[Union[Literal["output"], Op]]
1118+
idx: Optional[int]
1119+
reason: Optional[str]
11251120

1126-
idx = None
1127-
"""
1128-
Change events involve an position index of the input variable.
1129-
1130-
"""
1131-
1132-
reason = None
1133-
"""
1134-
Change events sometimes have a reason.
1135-
1136-
"""
1137-
1138-
def __init__(self, kind, node, idx=None, reason=None):
1121+
def __init__(
1122+
self,
1123+
kind: Literal["import", "change", "prune"],
1124+
old_node: Union[Literal["output"], "Apply"],
1125+
new_node: Union[Literal["output"], "Apply"] = None,
1126+
idx: Optional[int] = None,
1127+
reason: Optional[str] = None,
1128+
):
11391129
self.kind = kind
1140-
if node == "output":
1141-
self.node = "output"
1130+
if old_node == "output":
1131+
self.old_node = "output"
1132+
self.new_node = "output"
11421133
self.op = "output"
11431134
else:
1144-
self.node = node
1145-
self.op = node.op
1135+
self.old_node = old_node
1136+
self.new_node = new_node
1137+
self.op = old_node.op
11461138
self.idx = idx
1147-
self.reason = str(reason)
1139+
self.reason = str(reason) if reason else None
11481140

11491141
def __str__(self):
11501142
if self.kind == "change":
@@ -1218,21 +1210,21 @@ def on_attach(self, fgraph):
12181210
self.replaced_by = {}
12191211
self.event_list = []
12201212
for node in fgraph.toposort():
1221-
self.on_import(fgraph, node, "on_attach")
1213+
self.on_import(fgraph, node, reason="on_attach")
12221214

12231215
def on_detach(self, fgraph):
12241216
assert fgraph is self.fgraph
12251217
self.fgraph = None
12261218

12271219
def on_prune(self, fgraph, node, reason):
1228-
self.event_list.append(_FunctionGraphEvent("prune", node, reason=str(reason)))
1220+
self.event_list.append(_FunctionGraphEvent("prune", node, reason=reason))
12291221
assert node in self.active_nodes
12301222
assert node not in self.inactive_nodes
12311223
self.active_nodes.remove(node)
12321224
self.inactive_nodes.add(node)
12331225

12341226
def on_import(self, fgraph, node, reason):
1235-
self.event_list.append(_FunctionGraphEvent("import", node, reason=str(reason)))
1227+
self.event_list.append(_FunctionGraphEvent("import", node, reason=reason))
12361228

12371229
assert node not in self.active_nodes
12381230
self.active_nodes.add(node)
@@ -1252,31 +1244,36 @@ def on_import(self, fgraph, node, reason):
12521244
self.reasons.setdefault(r, [])
12531245
self.replaced_by.setdefault(r, [])
12541246

1255-
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
1247+
def on_change_input(
1248+
self, fgraph, old_node, new_node, i, old_var, new_var, reason=None
1249+
):
12561250
reason = str(reason)
12571251
self.event_list.append(
1258-
_FunctionGraphEvent("change", node, reason=reason, idx=i)
1252+
_FunctionGraphEvent("change", old_node, new_node, idx=i, reason=reason)
12591253
)
12601254

1261-
self.reasons.setdefault(new_r, [])
1262-
self.replaced_by.setdefault(new_r, [])
1255+
self.on_import(fgraph, new_node, reason=reason)
1256+
self.on_prune(fgraph, old_node, reason=reason)
1257+
1258+
self.reasons.setdefault(new_var, [])
1259+
self.replaced_by.setdefault(new_var, [])
12631260

12641261
append_reason = True
1265-
for tup in self.reasons[new_r]:
1266-
if tup[0] == reason and tup[1] is r:
1262+
for tup in self.reasons[new_var]:
1263+
if tup[0] == reason and tup[1] is old_var:
12671264
append_reason = False
12681265

12691266
if append_reason:
12701267
# N.B. compute the debugprint now, because future
12711268
# optimizations will change the graph
12721269
done = dict()
12731270
used_ids = dict()
1274-
self.reasons[new_r].append(
1271+
self.reasons[new_var].append(
12751272
(
12761273
reason,
1277-
r,
1274+
old_var,
12781275
_debugprint(
1279-
r,
1276+
old_var,
12801277
prefix=" ",
12811278
depth=6,
12821279
file=StringIO(),
@@ -1285,7 +1282,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
12851282
used_ids=used_ids,
12861283
).getvalue(),
12871284
_debugprint(
1288-
new_r,
1285+
new_var,
12891286
prefix=" ",
12901287
depth=6,
12911288
file=StringIO(),
@@ -1295,22 +1292,22 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
12951292
).getvalue(),
12961293
)
12971294
)
1298-
self.replaced_by[r].append((reason, new_r))
1295+
self.replaced_by[old_var].append((reason, new_var))
12991296

1300-
if r in self.equiv:
1301-
r_set = self.equiv[r]
1297+
if old_var in self.equiv:
1298+
r_set = self.equiv[old_var]
13021299
else:
1303-
r_set = self.equiv.setdefault(r, {r})
1304-
self.all_variables_ever.append(r)
1300+
r_set = self.equiv.setdefault(old_var, {old_var})
1301+
self.all_variables_ever.append(old_var)
13051302

1306-
if new_r in self.equiv:
1307-
new_r_set = self.equiv[new_r]
1303+
if new_var in self.equiv:
1304+
new_r_set = self.equiv[new_var]
13081305
else:
1309-
new_r_set = self.equiv.setdefault(new_r, {new_r})
1310-
self.all_variables_ever.append(new_r)
1306+
new_r_set = self.equiv.setdefault(new_var, {new_var})
1307+
self.all_variables_ever.append(new_var)
13111308

1312-
assert new_r in new_r_set
1313-
assert r in r_set
1309+
assert new_var in new_r_set
1310+
assert old_var in r_set
13141311

13151312
# update one equivalence set to contain the other
13161313
# transfer all the elements of the old one to the new one
@@ -1319,8 +1316,8 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
13191316
self.equiv[like_new_r] = r_set
13201317
assert like_new_r in r_set
13211318

1322-
assert self.equiv[r] is r_set
1323-
assert self.equiv[new_r] is r_set
1319+
assert self.equiv[old_var] is r_set
1320+
assert self.equiv[new_var] is r_set
13241321

13251322
def printstuff(self):
13261323
for key in self.equiv:

0 commit comments

Comments
 (0)