@@ -470,7 +470,7 @@ def prepare_module(self, module_op: Operation):
470470        ...
471471
472472    def  resolve_literal (
473-         self , gni : "GraphNodeImporter" , literal : Any 
473+         self , gni : "GraphNodeImporter" , literal : Any ,  info :  Optional [ InputInfo ] 
474474    ) ->  Optional [Value ]:
475475        """User overridable hook to resolve a literal value.""" 
476476        return  None 
@@ -1826,13 +1826,13 @@ def _convert_type(
18261826            name = op_name , results = [result_type ], operands = operands 
18271827        ).result 
18281828
1829-     def  _import_literal (self , py_value : Any ) ->  Value :
1829+     def  _import_literal (self , py_value : Any ,  info :  Optional [ InputInfo ]  =   None ) ->  Value :
18301830        orig_value  =  None 
18311831        if  isinstance (py_value , torch .Tensor ) and  py_value .dtype  ==  torch .bool :
18321832            orig_value  =  py_value 
18331833            py_value  =  py_value .to (torch .uint8 )
18341834        # Apply the conversion callback. 
1835-         user_value  =  self .fx_importer ._hooks .resolve_literal (self , py_value )
1835+         user_value  =  self .fx_importer ._hooks .resolve_literal (self , py_value ,  info )
18361836        if  user_value  is  not None :
18371837            assert  isinstance (user_value , Value )
18381838            if  orig_value  is  not None :
@@ -1866,7 +1866,7 @@ def _import_input(self, py_value: Any, info: InputInfo) -> Value:
18661866            raise  ValueError (
18671867                f"Cannot import { info .input_spec }  
18681868            )
1869-         return  self ._import_literal (py_value )
1869+         return  self ._import_literal (py_value ,  info )
18701870
18711871    def  _import_scalar_as_tensor (self , loc : Location , arg : NodeArgument ) ->  Value :
18721872        tensor_arg  =  torch .tensor (arg )
0 commit comments