14
14
from itertools import chain
15
15
from itertools import product as itertools_product
16
16
from logging import Logger
17
- from typing import Optional
17
+ from typing import TYPE_CHECKING , Optional , Union
18
18
from warnings import warn
19
19
20
20
import numpy as np
21
+ from typing_extensions import Literal
21
22
22
23
import aesara
23
24
from aesara .compile .function .types import (
42
43
from aesara .utils import NoDuplicateOptWarningFilter , difference , get_unbound_function
43
44
44
45
45
- __docformat__ = "restructuredtext en"
46
+ if TYPE_CHECKING :
47
+ from aesara .graph .basic import Apply
48
+
46
49
_logger : Logger = logging .getLogger ("aesara.compile.debugmode" )
47
50
_logger .addFilter (NoDuplicateOptWarningFilter ())
48
51
@@ -1108,43 +1111,32 @@ class _FunctionGraphEvent:
1108
1111
1109
1112
"""
1110
1113
1111
- kind = ""
1112
- """
1113
- One of 'import', 'change', 'prune'.
1114
-
1115
- """
1116
-
1117
- node = None
1118
- """
1119
- Either 'output' or an Apply instance.
1120
-
1121
- """
1122
-
1123
- op = None
1124
- """Either 'output' or an Op instance"""
1114
+ kind : Literal ["import" , "change" , "prune" ]
1115
+ old_node : Optional [Union [Literal ["output" ], "Apply" ]]
1116
+ new_node : Optional [Union [Literal ["output" ], "Apply" ]]
1117
+ op : Optional [Union [Literal ["output" ], Op ]]
1118
+ idx : Optional [int ]
1119
+ reason : Optional [str ]
1125
1120
1126
- idx = None
1127
- """
1128
- Change events involve an position index of the input variable.
1129
-
1130
- """
1131
-
1132
- reason = None
1133
- """
1134
- Change events sometimes have a reason.
1135
-
1136
- """
1137
-
1138
- def __init__ (self , kind , node , idx = None , reason = None ):
1121
+ def __init__ (
1122
+ self ,
1123
+ kind : Literal ["import" , "change" , "prune" ],
1124
+ old_node : Union [Literal ["output" ], "Apply" ],
1125
+ new_node : Union [Literal ["output" ], "Apply" ] = None ,
1126
+ idx : Optional [int ] = None ,
1127
+ reason : Optional [str ] = None ,
1128
+ ):
1139
1129
self .kind = kind
1140
- if node == "output" :
1141
- self .node = "output"
1130
+ if old_node == "output" :
1131
+ self .old_node = "output"
1132
+ self .new_node = "output"
1142
1133
self .op = "output"
1143
1134
else :
1144
- self .node = node
1145
- self .op = node .op
1135
+ self .old_node = old_node
1136
+ self .new_node = new_node
1137
+ self .op = old_node .op
1146
1138
self .idx = idx
1147
- self .reason = str (reason )
1139
+ self .reason = str (reason ) if reason else None
1148
1140
1149
1141
def __str__ (self ):
1150
1142
if self .kind == "change" :
@@ -1218,21 +1210,21 @@ def on_attach(self, fgraph):
1218
1210
self .replaced_by = {}
1219
1211
self .event_list = []
1220
1212
for node in fgraph .toposort ():
1221
- self .on_import (fgraph , node , "on_attach" )
1213
+ self .on_import (fgraph , node , reason = "on_attach" )
1222
1214
1223
1215
def on_detach (self , fgraph ):
1224
1216
assert fgraph is self .fgraph
1225
1217
self .fgraph = None
1226
1218
1227
1219
def on_prune (self , fgraph , node , reason ):
1228
- self .event_list .append (_FunctionGraphEvent ("prune" , node , reason = str ( reason ) ))
1220
+ self .event_list .append (_FunctionGraphEvent ("prune" , node , reason = reason ))
1229
1221
assert node in self .active_nodes
1230
1222
assert node not in self .inactive_nodes
1231
1223
self .active_nodes .remove (node )
1232
1224
self .inactive_nodes .add (node )
1233
1225
1234
1226
def on_import (self , fgraph , node , reason ):
1235
- self .event_list .append (_FunctionGraphEvent ("import" , node , reason = str ( reason ) ))
1227
+ self .event_list .append (_FunctionGraphEvent ("import" , node , reason = reason ))
1236
1228
1237
1229
assert node not in self .active_nodes
1238
1230
self .active_nodes .add (node )
@@ -1252,31 +1244,36 @@ def on_import(self, fgraph, node, reason):
1252
1244
self .reasons .setdefault (r , [])
1253
1245
self .replaced_by .setdefault (r , [])
1254
1246
1255
- def on_change_input (self , fgraph , node , i , r , new_r , reason = None ):
1247
+ def on_change_input (
1248
+ self , fgraph , old_node , new_node , i , old_var , new_var , reason = None
1249
+ ):
1256
1250
reason = str (reason )
1257
1251
self .event_list .append (
1258
- _FunctionGraphEvent ("change" , node , reason = reason , idx = i )
1252
+ _FunctionGraphEvent ("change" , old_node , new_node , idx = i , reason = reason )
1259
1253
)
1260
1254
1261
- self .reasons .setdefault (new_r , [])
1262
- self .replaced_by .setdefault (new_r , [])
1255
+ self .on_import (fgraph , new_node , reason = reason )
1256
+ self .on_prune (fgraph , old_node , reason = reason )
1257
+
1258
+ self .reasons .setdefault (new_var , [])
1259
+ self .replaced_by .setdefault (new_var , [])
1263
1260
1264
1261
append_reason = True
1265
- for tup in self .reasons [new_r ]:
1266
- if tup [0 ] == reason and tup [1 ] is r :
1262
+ for tup in self .reasons [new_var ]:
1263
+ if tup [0 ] == reason and tup [1 ] is old_var :
1267
1264
append_reason = False
1268
1265
1269
1266
if append_reason :
1270
1267
# N.B. compute the debugprint now, because future
1271
1268
# optimizations will change the graph
1272
1269
done = dict ()
1273
1270
used_ids = dict ()
1274
- self .reasons [new_r ].append (
1271
+ self .reasons [new_var ].append (
1275
1272
(
1276
1273
reason ,
1277
- r ,
1274
+ old_var ,
1278
1275
_debugprint (
1279
- r ,
1276
+ old_var ,
1280
1277
prefix = " " ,
1281
1278
depth = 6 ,
1282
1279
file = StringIO (),
@@ -1285,7 +1282,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
1285
1282
used_ids = used_ids ,
1286
1283
).getvalue (),
1287
1284
_debugprint (
1288
- new_r ,
1285
+ new_var ,
1289
1286
prefix = " " ,
1290
1287
depth = 6 ,
1291
1288
file = StringIO (),
@@ -1295,22 +1292,22 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
1295
1292
).getvalue (),
1296
1293
)
1297
1294
)
1298
- self .replaced_by [r ].append ((reason , new_r ))
1295
+ self .replaced_by [old_var ].append ((reason , new_var ))
1299
1296
1300
- if r in self .equiv :
1301
- r_set = self .equiv [r ]
1297
+ if old_var in self .equiv :
1298
+ r_set = self .equiv [old_var ]
1302
1299
else :
1303
- r_set = self .equiv .setdefault (r , {r })
1304
- self .all_variables_ever .append (r )
1300
+ r_set = self .equiv .setdefault (old_var , {old_var })
1301
+ self .all_variables_ever .append (old_var )
1305
1302
1306
- if new_r in self .equiv :
1307
- new_r_set = self .equiv [new_r ]
1303
+ if new_var in self .equiv :
1304
+ new_r_set = self .equiv [new_var ]
1308
1305
else :
1309
- new_r_set = self .equiv .setdefault (new_r , {new_r })
1310
- self .all_variables_ever .append (new_r )
1306
+ new_r_set = self .equiv .setdefault (new_var , {new_var })
1307
+ self .all_variables_ever .append (new_var )
1311
1308
1312
- assert new_r in new_r_set
1313
- assert r in r_set
1309
+ assert new_var in new_r_set
1310
+ assert old_var in r_set
1314
1311
1315
1312
# update one equivalence set to contain the other
1316
1313
# transfer all the elements of the old one to the new one
@@ -1319,8 +1316,8 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
1319
1316
self .equiv [like_new_r ] = r_set
1320
1317
assert like_new_r in r_set
1321
1318
1322
- assert self .equiv [r ] is r_set
1323
- assert self .equiv [new_r ] is r_set
1319
+ assert self .equiv [old_var ] is r_set
1320
+ assert self .equiv [new_var ] is r_set
1324
1321
1325
1322
def printstuff (self ):
1326
1323
for key in self .equiv :
0 commit comments