@@ -311,7 +311,7 @@ def import_var(
311
311
312
312
if isinstance (var .type , NullType ):
313
313
raise TypeError (
314
- f"Computation graph contains a NaN. { var .type .why_null } "
314
+ f"Computation graph contains a null type: { var } { var .type .why_null } "
315
315
)
316
316
if import_missing :
317
317
self .add_input (var )
@@ -327,7 +327,7 @@ def import_node(
327
327
reason : Optional [str ] = None ,
328
328
import_missing : bool = False ,
329
329
) -> None :
330
- """Recursively import everything between an `` Apply`` node and the `` FunctionGraph` `'s outputs.
330
+ """Recursively import everything between an `Apply` node and the `FunctionGraph`'s outputs.
331
331
332
332
Parameters
333
333
----------
@@ -347,42 +347,62 @@ def import_node(
347
347
# to know where to stop going down.)
348
348
new_nodes = io_toposort (self .variables , apply_node .outputs )
349
349
350
- if check :
351
- for node in new_nodes :
352
- for var in node .inputs :
353
- if (
354
- var .owner is None
355
- and not isinstance (var , AtomicVariable )
356
- and var not in self .inputs
357
- ):
358
- if import_missing :
359
- self .add_input (var )
360
- else :
361
- error_msg = (
362
- f"Input { node .inputs .index (var )} ({ var } )"
363
- " of the graph (indices start "
364
- f"from 0), used to compute { node } , was not "
365
- "provided and not given a value. Use the "
366
- "Aesara flag exception_verbosity='high', "
367
- "for more information on this error."
368
- )
369
- raise MissingInputError (error_msg , variable = var )
370
-
371
350
for node in new_nodes :
372
- assert node not in self .apply_nodes
373
- self .apply_nodes .add (node )
374
- if not hasattr (node .tag , "imported_by" ):
375
- node .tag .imported_by = []
376
- node .tag .imported_by .append (str (reason ))
377
- for output in node .outputs :
378
- self .setup_var (output )
379
- self .variables .add (output )
380
- for i , input in enumerate (node .inputs ):
381
- if input not in self .variables :
382
- self .setup_var (input )
383
- self .variables .add (input )
384
- self .add_client (input , (node , i ))
385
- self .execute_callbacks ("on_import" , node , reason )
351
+ self ._import_node (
352
+ node , check = check , reason = reason , import_missing = import_missing
353
+ )
354
+
355
+ def _import_node (
356
+ self ,
357
+ apply_node : Apply ,
358
+ check : bool = True ,
359
+ reason : Optional [str ] = None ,
360
+ import_missing : bool = False ,
361
+ ) -> None :
362
+ """Import a single node.
363
+
364
+ See `FunctionGraph.import_node`.
365
+ """
366
+ assert apply_node not in self .apply_nodes
367
+
368
+ for i , inp in enumerate (apply_node .inputs ):
369
+ if (
370
+ check
371
+ and inp .owner is None
372
+ and not isinstance (inp , AtomicVariable )
373
+ and inp not in self .inputs
374
+ ):
375
+ if import_missing :
376
+ self .add_input (inp )
377
+ else :
378
+ error_msg = (
379
+ f"Input { apply_node .inputs .index (inp )} ({ inp } )"
380
+ " of the graph (indices start "
381
+ f"from 0), used to compute { apply_node } , was not "
382
+ "provided and not given a value. Use the "
383
+ "Aesara flag exception_verbosity='high', "
384
+ "for more information on this error."
385
+ )
386
+ raise MissingInputError (error_msg , variable = inp )
387
+
388
+ if inp not in self .variables :
389
+ self .setup_var (inp )
390
+ self .variables .add (inp )
391
+
392
+ self .add_client (inp , (apply_node , i ))
393
+
394
+ for output in apply_node .outputs :
395
+ self .setup_var (output )
396
+ self .variables .add (output )
397
+
398
+ self .apply_nodes .add (apply_node )
399
+
400
+ if not hasattr (apply_node .tag , "imported_by" ):
401
+ apply_node .tag .imported_by = []
402
+
403
+ apply_node .tag .imported_by .append (str (reason ))
404
+
405
+ self .execute_callbacks ("on_import" , apply_node , reason )
386
406
387
407
def change_node_input (
388
408
self ,
0 commit comments