Skip to content

Commit 561a600

Browse files
authored
[torchlib] Use traced function param schema to process inputs (#1916)
The first step of #1914: this sets up onnxscript CI to test whether a traced-only function has enough information to process inputs into tensors.
1 parent 2b60939 commit 561a600

File tree

4 files changed

+40
-19
lines changed

4 files changed

+40
-19
lines changed

onnxscript/_internal/param_manipulation.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,18 @@ def tag_arguments_with_param_schemas(
131131
raise TypeError(f"Required input/attribute '{param}' was not provided")
132132

133133
return tagged_args, tagged_kwargs
134+
135+
136+
def turn_to_kwargs_to_avoid_ordering(
    param_schemas: Sequence[values.ParamSchema],
    inputs: list[Any],
    attributes: dict[str, Any],
) -> dict[str, Any]:
    """Fold positional ``inputs`` into ``attributes`` keyed by schema parameter name.

    Walks ``param_schemas`` in signature order and, for each parameter not
    already present in ``attributes``, consumes the next positional input.
    A variadic input parameter receives *all* remaining inputs.

    Args:
        param_schemas: Parameter schemas in the function-signature order.
        inputs: Positional inputs; consumed (mutated) by this function.
        attributes: Keyword arguments/attributes already known; updated in place.

    Returns:
        ``attributes``, now containing every input keyed by parameter name.
    """
    for param in param_schemas:
        if param.name in attributes:
            # Already provided as a keyword/attribute; do not consume an input.
            continue
        if param.is_variadic_input:
            # BUGFIX: the variadic parameter takes all *remaining* inputs.
            # The previous `inputs[idx:]` indexed by schema position, but
            # earlier `pop(0)` calls re-base the list at 0, which dropped
            # leading variadic inputs whenever non-variadic params preceded it.
            attributes[param.name] = list(inputs)
            inputs.clear()  # consumed; prevents double-assignment to later params
        elif inputs:
            attributes[param.name] = inputs.pop(0)
    return attributes

onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -390,9 +390,6 @@ def eval_function( # type: ignore[override]
390390
else:
391391
# Python constants are scalars
392392
return 0
393-
elif function.traceable:
394-
# Trace the function call instead of adding the function as a node
395-
return function.function(*args, **kwargs)
396393

397394
# args/kwargs are TorchScriptTensor/python built-in based
398395
param_schemas = function.param_schemas()
@@ -422,6 +419,15 @@ def eval_function( # type: ignore[override]
422419
value, float
423420
):
424421
attributes[name] = (value,)
422+
if function.traceable:
423+
inputs = self._graph.preprocess_inputs(inputs)
424+
inputs = _wrap_torch_value_to_tensor(inputs) # type: ignore[assignment]
425+
# The args and kwargs matter, as this is a traced ONNX function
426+
kwargs = param_manipulation.turn_to_kwargs_to_avoid_ordering(
427+
param_schemas, inputs, attributes
428+
)
429+
# Trace the function call instead of adding the function as a node
430+
return function.function(**kwargs)
425431
return self._graph.add_function_call(function, inputs, attributes)
426432

427433

@@ -730,14 +736,7 @@ def _add_constant_to_graph(self, constant) -> torch.Value:
730736
value.setDebugName(_rename_intermediate_value(value.debugName()))
731737
return value
732738

733-
@runtime_typing.checked
734-
def _add_torchscript_op_call(
735-
self,
736-
name: str,
737-
onnx_inputs: Sequence[ValidInputType],
738-
onnx_attributes: Mapping[str, ValidArgumentType],
739-
n_outputs: int,
740-
) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]:
739+
def preprocess_inputs(self, onnx_inputs: Sequence[ValidInputType]) -> List[torch.Value]:
741740
unwrapped_inputs = _unwrap_tensors_to_torch_values(onnx_inputs)
742741
graph_inputs = []
743742
assert isinstance(unwrapped_inputs, Sequence)
@@ -761,6 +760,17 @@ def _add_torchscript_op_call(
761760
graph_inputs.append(self._add_constant_to_graph(input))
762761
else:
763762
graph_inputs.append(input)
763+
return graph_inputs
764+
765+
@runtime_typing.checked
766+
def _add_torchscript_op_call(
767+
self,
768+
name: str,
769+
onnx_inputs: Sequence[ValidInputType],
770+
onnx_attributes: Mapping[str, ValidArgumentType],
771+
n_outputs: int,
772+
) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]:
773+
graph_inputs = self.preprocess_inputs(onnx_inputs)
764774
for key, value in onnx_attributes.items():
765775
assert not isinstance(
766776
value, TorchScriptTensor

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5752,14 +5752,9 @@ def aten_nansum(
57525752
def aten_narrow(self: TTensor, dim: INT64, start: INT64, length: INT64) -> TTensor:
    """narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)"""

    # Slice expects rank-1 index tensors: promote each scalar to shape [-1]
    # (three separate Constant nodes, matching the traced graph exactly).
    dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
    start = op.Reshape(start, op.Constant(value_ints=[-1]))
    length = op.Reshape(length, op.Constant(value_ints=[-1]))

    # narrow(start, length) is the half-open window [start, start + length) along dim.
    stop = op.Add(start, length)
    return op.Slice(self, start, stop, dim)

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,6 +1349,7 @@ def _where_input_wrangler(
13491349
.xfail(
13501350
variant_name="decimals_0",
13511351
reason="This variant does not accept decimals",
1352+
test_class_name="TestOutputConsistencyEager",
13521353
)
13531354
.xfail(
13541355
variant_name="decimals_3",

0 commit comments

Comments
 (0)