14
14
from itertools import chain
15
15
from itertools import product as itertools_product
16
16
from logging import Logger
17
- from typing import Optional
17
+ from typing import TYPE_CHECKING , Optional , Union
18
18
from warnings import warn
19
19
20
20
import numpy as np
21
+ from typing_extensions import Literal
21
22
22
23
import aesara
23
24
from aesara .compile .function .types import (
42
43
from aesara .utils import NoDuplicateOptWarningFilter , difference , get_unbound_function
43
44
44
45
45
- __docformat__ = "restructuredtext en"
46
+ if TYPE_CHECKING :
47
+ from aesara .graph .basic import Apply
48
+
46
49
_logger : Logger = logging .getLogger ("aesara.compile.debugmode" )
47
50
_logger .addFilter (NoDuplicateOptWarningFilter ())
48
51
@@ -1109,43 +1112,32 @@ class _FunctionGraphEvent:
1109
1112
1110
1113
"""
1111
1114
1112
- kind = ""
1113
- """
1114
- One of 'import', 'change', 'prune'.
1115
-
1116
- """
1117
-
1118
- node = None
1119
- """
1120
- Either 'output' or an Apply instance.
1121
-
1122
- """
1123
-
1124
- op = None
1125
- """Either 'output' or an Op instance"""
1115
+ kind : Literal ["import" , "change" , "prune" ]
1116
+ old_node : Optional [Union [Literal ["output" ], "Apply" ]]
1117
+ new_node : Optional [Union [Literal ["output" ], "Apply" ]]
1118
+ op : Optional [Union [Literal ["output" ], Op ]]
1119
+ idx : Optional [int ]
1120
+ reason : Optional [str ]
1126
1121
1127
- idx = None
1128
- """
1129
- Change events involve an position index of the input variable.
1130
-
1131
- """
1132
-
1133
- reason = None
1134
- """
1135
- Change events sometimes have a reason.
1136
-
1137
- """
1138
-
1139
- def __init__ (self , kind , node , idx = None , reason = None ):
1122
+ def __init__ (
1123
+ self ,
1124
+ kind : Literal ["import" , "change" , "prune" ],
1125
+ old_node : Union [Literal ["output" ], "Apply" ],
1126
+ new_node : Union [Literal ["output" ], "Apply" ] = None ,
1127
+ idx : Optional [int ] = None ,
1128
+ reason : Optional [str ] = None ,
1129
+ ):
1140
1130
self .kind = kind
1141
- if node == "output" :
1142
- self .node = "output"
1131
+ if old_node == "output" :
1132
+ self .old_node = "output"
1133
+ self .new_node = "output"
1143
1134
self .op = "output"
1144
1135
else :
1145
- self .node = node
1146
- self .op = node .op
1136
+ self .old_node = old_node
1137
+ self .new_node = new_node
1138
+ self .op = old_node .op
1147
1139
self .idx = idx
1148
- self .reason = str (reason )
1140
+ self .reason = str (reason ) if reason else None
1149
1141
1150
1142
def __str__ (self ):
1151
1143
if self .kind == "change" :
@@ -1219,21 +1211,21 @@ def on_attach(self, fgraph):
1219
1211
self .replaced_by = {}
1220
1212
self .event_list = []
1221
1213
for node in fgraph .toposort ():
1222
- self .on_import (fgraph , node , "on_attach" )
1214
+ self .on_import (fgraph , node , reason = "on_attach" )
1223
1215
1224
1216
def on_detach (self , fgraph ):
1225
1217
assert fgraph is self .fgraph
1226
1218
self .fgraph = None
1227
1219
1228
1220
def on_prune (self , fgraph , node , reason ):
1229
- self .event_list .append (_FunctionGraphEvent ("prune" , node , reason = str ( reason ) ))
1221
+ self .event_list .append (_FunctionGraphEvent ("prune" , node , reason = reason ))
1230
1222
assert node in self .active_nodes
1231
1223
assert node not in self .inactive_nodes
1232
1224
self .active_nodes .remove (node )
1233
1225
self .inactive_nodes .add (node )
1234
1226
1235
1227
def on_import (self , fgraph , node , reason ):
1236
- self .event_list .append (_FunctionGraphEvent ("import" , node , reason = str ( reason ) ))
1228
+ self .event_list .append (_FunctionGraphEvent ("import" , node , reason = reason ))
1237
1229
1238
1230
assert node not in self .active_nodes
1239
1231
self .active_nodes .add (node )
@@ -1253,31 +1245,36 @@ def on_import(self, fgraph, node, reason):
1253
1245
self .reasons .setdefault (r , [])
1254
1246
self .replaced_by .setdefault (r , [])
1255
1247
1256
- def on_change_input (self , fgraph , node , i , r , new_r , reason = None ):
1248
+ def on_change_input (
1249
+ self , fgraph , old_node , new_node , i , old_var , new_var , reason = None
1250
+ ):
1257
1251
reason = str (reason )
1258
1252
self .event_list .append (
1259
- _FunctionGraphEvent ("change" , node , reason = reason , idx = i )
1253
+ _FunctionGraphEvent ("change" , old_node , new_node , idx = i , reason = reason )
1260
1254
)
1261
1255
1262
- self .reasons .setdefault (new_r , [])
1263
- self .replaced_by .setdefault (new_r , [])
1256
+ self .on_import (fgraph , new_node , reason = reason )
1257
+ self .on_prune (fgraph , old_node , reason = reason )
1258
+
1259
+ self .reasons .setdefault (new_var , [])
1260
+ self .replaced_by .setdefault (new_var , [])
1264
1261
1265
1262
append_reason = True
1266
- for tup in self .reasons [new_r ]:
1267
- if tup [0 ] == reason and tup [1 ] is r :
1263
+ for tup in self .reasons [new_var ]:
1264
+ if tup [0 ] == reason and tup [1 ] is old_var :
1268
1265
append_reason = False
1269
1266
1270
1267
if append_reason :
1271
1268
# N.B. compute the debugprint now, because future
1272
1269
# optimizations will change the graph
1273
1270
done = dict ()
1274
1271
used_ids = dict ()
1275
- self .reasons [new_r ].append (
1272
+ self .reasons [new_var ].append (
1276
1273
(
1277
1274
reason ,
1278
- r ,
1275
+ old_var ,
1279
1276
_debugprint (
1280
- r ,
1277
+ old_var ,
1281
1278
prefix = " " ,
1282
1279
depth = 6 ,
1283
1280
file = StringIO (),
@@ -1286,7 +1283,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
1286
1283
used_ids = used_ids ,
1287
1284
).getvalue (),
1288
1285
_debugprint (
1289
- new_r ,
1286
+ new_var ,
1290
1287
prefix = " " ,
1291
1288
depth = 6 ,
1292
1289
file = StringIO (),
@@ -1296,22 +1293,22 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
1296
1293
).getvalue (),
1297
1294
)
1298
1295
)
1299
- self .replaced_by [r ].append ((reason , new_r ))
1296
+ self .replaced_by [old_var ].append ((reason , new_var ))
1300
1297
1301
- if r in self .equiv :
1302
- r_set = self .equiv [r ]
1298
+ if old_var in self .equiv :
1299
+ r_set = self .equiv [old_var ]
1303
1300
else :
1304
- r_set = self .equiv .setdefault (r , {r })
1305
- self .all_variables_ever .append (r )
1301
+ r_set = self .equiv .setdefault (old_var , {old_var })
1302
+ self .all_variables_ever .append (old_var )
1306
1303
1307
- if new_r in self .equiv :
1308
- new_r_set = self .equiv [new_r ]
1304
+ if new_var in self .equiv :
1305
+ new_r_set = self .equiv [new_var ]
1309
1306
else :
1310
- new_r_set = self .equiv .setdefault (new_r , {new_r })
1311
- self .all_variables_ever .append (new_r )
1307
+ new_r_set = self .equiv .setdefault (new_var , {new_var })
1308
+ self .all_variables_ever .append (new_var )
1312
1309
1313
- assert new_r in new_r_set
1314
- assert r in r_set
1310
+ assert new_var in new_r_set
1311
+ assert old_var in r_set
1315
1312
1316
1313
# update one equivalence set to contain the other
1317
1314
# transfer all the elements of the old one to the new one
@@ -1320,8 +1317,8 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
1320
1317
self .equiv [like_new_r ] = r_set
1321
1318
assert like_new_r in r_set
1322
1319
1323
- assert self .equiv [r ] is r_set
1324
- assert self .equiv [new_r ] is r_set
1320
+ assert self .equiv [old_var ] is r_set
1321
+ assert self .equiv [new_var ] is r_set
1325
1322
1326
1323
def printstuff (self ):
1327
1324
for key in self .equiv :
0 commit comments