Skip to content

Commit b790061

Browse files
authored
[FxImporter] Add InputInfo to Resolve Literal Hook (#3688)
1 parent 295bf41 commit b790061

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

python/torch_mlir/extras/fx_importer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ def prepare_module(self, module_op: Operation):
470470
...
471471

472472
def resolve_literal(
473-
self, gni: "GraphNodeImporter", literal: Any
473+
self, gni: "GraphNodeImporter", literal: Any, info: Optional[InputInfo]
474474
) -> Optional[Value]:
475475
"""User overridable hook to resolve a literal value."""
476476
return None
@@ -1826,13 +1826,13 @@ def _convert_type(
18261826
name=op_name, results=[result_type], operands=operands
18271827
).result
18281828

1829-
def _import_literal(self, py_value: Any) -> Value:
1829+
def _import_literal(self, py_value: Any, info: Optional[InputInfo] = None) -> Value:
18301830
orig_value = None
18311831
if isinstance(py_value, torch.Tensor) and py_value.dtype == torch.bool:
18321832
orig_value = py_value
18331833
py_value = py_value.to(torch.uint8)
18341834
# Apply the conversion callback.
1835-
user_value = self.fx_importer._hooks.resolve_literal(self, py_value)
1835+
user_value = self.fx_importer._hooks.resolve_literal(self, py_value, info)
18361836
if user_value is not None:
18371837
assert isinstance(user_value, Value)
18381838
if orig_value is not None:
@@ -1866,7 +1866,7 @@ def _import_input(self, py_value: Any, info: InputInfo) -> Value:
18661866
raise ValueError(
18671867
f"Cannot import {info.input_spec} as a literal because it is mutable"
18681868
)
1869-
return self._import_literal(py_value)
1869+
return self._import_literal(py_value, info)
18701870

18711871
def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value:
18721872
tensor_arg = torch.tensor(arg)

0 commit comments

Comments
 (0)