5 changes: 0 additions & 5 deletions onnxscript/function_libs/torch_lib/_flags.py
@@ -51,8 +51,3 @@ def _load_boolean_flag(
this_will="trace all traceable functions to fold if branches and collapse constant expressions",
default=True,
)
EXPERIMENTAL_USE_IR: bool = _load_boolean_flag(
"TORCHLIB_EXPERIMENTAL_USE_IR",
this_will="use the ONNX IR instead of the PyTorch Graph for graph building",
deprecated=True,
)
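For context (only the tail of the surviving flag definition is visible above this hunk), here is a minimal sketch of what an environment-driven loader like `_load_boolean_flag` might look like. The real helper is defined earlier in `_flags.py`; the messaging details below are assumptions, not taken from the diff:

import logging
import os
import warnings

logger = logging.getLogger(__name__)

def _load_boolean_flag(
    name: str,
    *,
    this_will: str,
    deprecated: bool = False,
    default: bool = False,
) -> bool:
    # Sketch: read the flag from the environment; unset -> default, "1" -> True.
    state = os.getenv(name)
    if state is None:
        return default
    if deprecated:
        # Assumed behavior: warn when a deprecated flag is still set.
        warnings.warn(
            f"{name} is deprecated and will be removed in a future release.",
            stacklevel=2,
        )
    enabled = state == "1"
    if enabled:
        # Assumed: `this_will` feeds a user-facing message about the flag's effect.
        logger.warning("%s is set: this will %s.", name, this_will)
    return enabled

With `TORCHLIB_EXPERIMENTAL_USE_IR` deleted, the IR backend is no longer opt-in; the remaining tracing flag keeps its default of True.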
20 changes: 6 additions & 14 deletions onnxscript/function_libs/torch_lib/graph_building/__init__.py
@@ -40,17 +40,9 @@
"TorchScriptTracingEvaluator",
]

from onnxscript.function_libs.torch_lib import _flags

if _flags.EXPERIMENTAL_USE_IR:
from ._graph_building_ir import (
TorchScriptGraph,
TorchScriptTensor,
TorchScriptTracingEvaluator,
)
else:
from ._graph_building_torch import ( # type: ignore[assignment]
TorchScriptGraph,
TorchScriptTensor,
TorchScriptTracingEvaluator,
)

from ._graph_building_ir import (
TorchScriptGraph,
TorchScriptTensor,
TorchScriptTracingEvaluator,
)
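Since the conditional import is gone, the package now has a single graph-building backend and the public surface is unchanged. A sketch of a typical call site (the constructor usage is assumed, not shown in this diff):

from onnxscript.function_libs.torch_lib.graph_building import (
    TorchScriptGraph,
    TorchScriptTensor,
    TorchScriptTracingEvaluator,
)

# Assumed constructor shape: the evaluator records traced ops into the graph.
graph = TorchScriptGraph()
tracer = TorchScriptTracingEvaluator(graph)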
onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py
@@ -237,9 +237,6 @@ def eval_function( # type: ignore[override]
else:
# Python constants are scalars
return 0
elif function.traceable:
# Trace the function call instead of adding the function as a node
return function.function(*args, **kwargs)

# args/kwargs are TorchScriptTensor/python built-in based
param_schemas = function.param_schemas()
@@ -269,6 +266,10 @@ def eval_function( # type: ignore[override]
value, float
):
attributes[name] = (value,)
if function.traceable:
inputs = self._graph.preprocess_inputs(inputs, attributes)
# Trace the function call instead of adding the function as a node
return function.function(*inputs, **attributes)
return self._graph.add_function_call(function, inputs, attributes)
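Moving the `function.traceable` branch below the param-schema handling means traced functions now receive inputs and attributes that went through the same separation and preprocessing as functions emitted as call nodes, rather than the raw args/kwargs. Condensed, the resulting control flow is roughly (a sketch, not the verbatim method):

# After param-schema resolution has split `inputs` and `attributes`:
if function.traceable:
    # Inline the function: normalize inputs as add_function_call would,
    # then run the Python body so its ops are recorded into the graph.
    inputs = self._graph.preprocess_inputs(inputs, attributes)
    return function.function(*inputs, **attributes)
# Otherwise emit a single call node referencing the ONNX function.
return self._graph.add_function_call(function, inputs, attributes)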


@@ -522,15 +523,11 @@ def _add_constant_to_graph(self, constant) -> Sequence[ir.Value | None]:
)
return value

def _add_ir_graph_op_call(
def preprocess_inputs(
self,
*,
domain: str,
op_type: str,
onnx_inputs: Sequence[ValidInputType],
onnx_attributes: Mapping[str, ValidArgumentType],
num_outputs: int,
) -> Sequence[TorchScriptTensor]:
) -> list[TorchScriptTensor]:
graph_inputs: list[TorchScriptTensor] = []
assert isinstance(onnx_inputs, Sequence)
for input in onnx_inputs:
@@ -559,6 +556,18 @@ def _add_ir_graph_op_call(
assert not isinstance(
value, TorchScriptTensor
), f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
return graph_inputs

def _add_ir_graph_op_call(
self,
*,
domain: str,
op_type: str,
onnx_inputs: Sequence[ValidInputType],
onnx_attributes: Mapping[str, ValidArgumentType],
num_outputs: int,
) -> Sequence[TorchScriptTensor]:
graph_inputs = self.preprocess_inputs(onnx_inputs, onnx_attributes)
tensors = _create_op_call_in_graph(
self._graph,
domain,
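The diff viewer collapses the tail of `_add_ir_graph_op_call`. Pieced together from the visible lines, the refactored method reduces to roughly the following (a sketch: the arguments to `_create_op_call_in_graph` after `domain` are assumed from context, not shown in the hunk):

def _add_ir_graph_op_call(
    self,
    *,
    domain: str,
    op_type: str,
    onnx_inputs: Sequence[ValidInputType],
    onnx_attributes: Mapping[str, ValidArgumentType],
    num_outputs: int,
) -> Sequence[TorchScriptTensor]:
    # Normalization is now shared with the tracing path in eval_function
    # instead of being inlined here.
    graph_inputs = self.preprocess_inputs(onnx_inputs, onnx_attributes)
    tensors = _create_op_call_in_graph(
        self._graph,
        domain,
        op_type,                 # assumed: arguments past `domain`
        inputs=graph_inputs,     # and these keyword names are
        attributes=onnx_attributes,  # reconstructed, not from the diff
        num_outputs=num_outputs,
    )
    return tensors

Factoring `preprocess_inputs` out keeps constant wrapping and the attribute-type assertions in one place, so the tracing path and the node-emission path cannot drift apart.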