Skip to content

Commit 3c665a5

Browse files
Separate recursive importing from single node importing in FunctionGraph
1 parent 1cce2b0 commit 3c665a5

File tree

1 file changed

+57
-37
lines changed

1 file changed

+57
-37
lines changed

aesara/graph/fg.py

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def import_var(
311311

312312
if isinstance(var.type, NullType):
313313
raise TypeError(
314-
f"Computation graph contains a NaN. {var.type.why_null}"
314+
f"Computation graph contains a null type: {var} {var.type.why_null}"
315315
)
316316
if import_missing:
317317
self.add_input(var)
@@ -327,7 +327,7 @@ def import_node(
327327
reason: Optional[str] = None,
328328
import_missing: bool = False,
329329
) -> None:
330-
"""Recursively import everything between an ``Apply`` node and the ``FunctionGraph``'s outputs.
330+
"""Recursively import everything between an `Apply` node and the `FunctionGraph`'s outputs.
331331
332332
Parameters
333333
----------
@@ -347,42 +347,62 @@ def import_node(
347347
# to know where to stop going down.)
348348
new_nodes = io_toposort(self.variables, apply_node.outputs)
349349

350-
if check:
351-
for node in new_nodes:
352-
for var in node.inputs:
353-
if (
354-
var.owner is None
355-
and not isinstance(var, AtomicVariable)
356-
and var not in self.inputs
357-
):
358-
if import_missing:
359-
self.add_input(var)
360-
else:
361-
error_msg = (
362-
f"Input {node.inputs.index(var)} ({var})"
363-
" of the graph (indices start "
364-
f"from 0), used to compute {node}, was not "
365-
"provided and not given a value. Use the "
366-
"Aesara flag exception_verbosity='high', "
367-
"for more information on this error."
368-
)
369-
raise MissingInputError(error_msg, variable=var)
370-
371350
for node in new_nodes:
372-
assert node not in self.apply_nodes
373-
self.apply_nodes.add(node)
374-
if not hasattr(node.tag, "imported_by"):
375-
node.tag.imported_by = []
376-
node.tag.imported_by.append(str(reason))
377-
for output in node.outputs:
378-
self.setup_var(output)
379-
self.variables.add(output)
380-
for i, input in enumerate(node.inputs):
381-
if input not in self.variables:
382-
self.setup_var(input)
383-
self.variables.add(input)
384-
self.add_client(input, (node, i))
385-
self.execute_callbacks("on_import", node, reason)
351+
self._import_node(
352+
node, check=check, reason=reason, import_missing=import_missing
353+
)
354+
355+
def _import_node(
356+
self,
357+
apply_node: Apply,
358+
check: bool = True,
359+
reason: Optional[str] = None,
360+
import_missing: bool = False,
361+
) -> None:
362+
"""Import a single node.
363+
364+
See `FunctionGraph.import_node`.
365+
"""
366+
assert apply_node not in self.apply_nodes
367+
368+
for i, inp in enumerate(apply_node.inputs):
369+
if (
370+
check
371+
and inp.owner is None
372+
and not isinstance(inp, AtomicVariable)
373+
and inp not in self.inputs
374+
):
375+
if import_missing:
376+
self.add_input(inp)
377+
else:
378+
error_msg = (
379+
f"Input {apply_node.inputs.index(inp)} ({inp})"
380+
" of the graph (indices start "
381+
f"from 0), used to compute {apply_node}, was not "
382+
"provided and not given a value. Use the "
383+
"Aesara flag exception_verbosity='high', "
384+
"for more information on this error."
385+
)
386+
raise MissingInputError(error_msg, variable=inp)
387+
388+
if inp not in self.variables:
389+
self.setup_var(inp)
390+
self.variables.add(inp)
391+
392+
self.add_client(inp, (apply_node, i))
393+
394+
for output in apply_node.outputs:
395+
self.setup_var(output)
396+
self.variables.add(output)
397+
398+
self.apply_nodes.add(apply_node)
399+
400+
if not hasattr(apply_node.tag, "imported_by"):
401+
apply_node.tag.imported_by = []
402+
403+
apply_node.tag.imported_by.append(str(reason))
404+
405+
self.execute_callbacks("on_import", apply_node, reason)
386406

387407
def change_node_input(
388408
self,

0 commit comments

Comments
 (0)