Skip to content
Closed
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,6 @@
else:
# Python constants are scalars
return 0
elif function.traceable:
# Trace the function call instead of adding the function as a node
return function.function(*args, **kwargs)

# args/kwargs are TorchScriptTensor/python built-in based
param_schemas = function.param_schemas()
Expand Down Expand Up @@ -269,6 +266,10 @@
value, float
):
attributes[name] = (value,)
if function.traceable:
inputs = self._graph.precprocess_inputs(inputs, attributes)

Check warning on line 270 in onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py#L270

Added line #L270 was not covered by tests
# Trace the function call instead of adding the function as a node
return function.function(*inputs, **attributes)

Check warning on line 272 in onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py#L272

Added line #L272 was not covered by tests
return self._graph.add_function_call(function, inputs, attributes)


Expand Down Expand Up @@ -522,15 +523,11 @@
)
return value

def _add_ir_graph_op_call(
def precprocess_inputs(
self,
*,
domain: str,
op_type: str,
onnx_inputs: Sequence[ValidInputType],
onnx_attributes: Mapping[str, ValidArgumentType],
num_outputs: int,
) -> Sequence[TorchScriptTensor]:
) -> list[TorchScriptTensor]:
graph_inputs: list[TorchScriptTensor] = []
assert isinstance(onnx_inputs, Sequence)
for input in onnx_inputs:
Expand Down Expand Up @@ -559,6 +556,18 @@
assert not isinstance(
value, TorchScriptTensor
), f"ONNX attribute must not be a TorchScriptTensor, got {key}: {value}."
return graph_inputs

Check warning on line 559 in onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py#L559

Added line #L559 was not covered by tests

def _add_ir_graph_op_call(
self,
*,
domain: str,
op_type: str,
onnx_inputs: Sequence[ValidInputType],
onnx_attributes: Mapping[str, ValidArgumentType],
num_outputs: int,
) -> Sequence[TorchScriptTensor]:
graph_inputs = self.precprocess_inputs(onnx_inputs, onnx_attributes)

Check warning on line 570 in onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py#L570

Added line #L570 was not covered by tests
tensors = _create_op_call_in_graph(
self._graph,
domain,
Expand Down
Loading