88import logging
99import types
1010import typing
11- from typing import (
12- Any ,
13- Iterator ,
14- Mapping ,
15- Optional ,
16- Sequence ,
17- TypeVar ,
18- Union ,
19- )
11+ from typing import Any , Iterator , Mapping , Optional , Sequence , TypeVar , Union
2012
2113import onnx
2214
@@ -103,7 +95,7 @@ def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam:
10395
10496 @classmethod
10597 def any_value (cls , name : str , description : str = "" ) -> TypeConstraintParam :
106- return cls (name , _ALL_VALUE_TYPES , description ) # type: ignore
98+ return cls (name , _ALL_VALUE_TYPES , description ) # type: ignore[arg-type]
10799
108100
109101@dataclasses .dataclass (frozen = True )
@@ -129,6 +121,8 @@ def has_default(self) -> bool:
129121
130122@dataclasses .dataclass (frozen = True )
131123class AttributeParameter :
124+ """A parameter in the function signature that represents an ONNX attribute."""
125+
132126 name : str
133127 type : ir .AttributeType
134128 required : bool
@@ -147,7 +141,7 @@ def has_default(self) -> bool:
147141def _get_type_from_str (
148142 type_str : str ,
149143) -> ir .TensorType | ir .SequenceType | ir .OptionalType :
150- """Converter a type_str from ONNX Opschema to ir.TypeProtocol.
144+ """Converter a type_str from ONNX OpSchema to ir.TypeProtocol.
151145
152146 A type str has the form of "tensor(float)" or composite type like "seq(tensor(float))".
153147 """
@@ -180,14 +174,14 @@ def _convert_formal_parameter(
180174 param : onnx .defs .OpSchema .FormalParameter ,
181175 type_constraints : Mapping [str , TypeConstraintParam ],
182176) -> Parameter :
183- """Convert a formal parameter from ONNX Opschema to Parameter."""
177+ """Convert a formal parameter from ONNX OpSchema to Parameter."""
184178 if param .type_str in type_constraints :
185179 type_constraint = type_constraints [param .type_str ]
186180 else :
187181 # param.type_str can be a plain type like 'int64'.
188182 type_constraint = TypeConstraintParam (
189183 name = param .name ,
190- allowed_types = {_get_type_from_str (param .type_str )}, # type: ignore
184+ allowed_types = {_get_type_from_str (param .type_str )},
191185 )
192186 return Parameter (
193187 name = param .name ,
@@ -377,7 +371,7 @@ def __str__(self) -> str:
377371
378372 @classmethod
379373 def from_op_schema (cls , op_schema : onnx .defs .OpSchema ) -> OpSignature :
380- """Produce an OpSignature from an ONNX Opschema ."""
374+ """Produce an OpSignature from an ONNX OpSchema ."""
381375 type_constraints = {
382376 constraint .type_param_str : TypeConstraintParam (
383377 name = constraint .type_param_str ,
@@ -434,7 +428,7 @@ def from_function(
434428 # https://github.com/python/cpython/issues/102405
435429 type_hints = typing .get_type_hints (func )
436430
437- params = []
431+ params : list [ Parameter | AttributeParameter ] = []
438432 # Create a mapping from type to a unique name
439433 type_constraints : dict [str , TypeConstraintParam ] = {}
440434
@@ -445,7 +439,20 @@ def from_function(
445439 param .name ,
446440 py_signature ,
447441 )
448- type_constraints [param .name ] = TypeConstraintParam .any_value (f"T_{ param .name } " )
442+ type_constraint = TypeConstraintParam .any_value (f"T_{ param .name } " )
443+ type_constraints [param .name ] = type_constraint
444+ params .append (
445+ Parameter (
446+ name = param .name ,
447+ type_constraint = type_constraint ,
448+ required = param .default is inspect .Parameter .empty ,
449+ # TODO: Handle variadic
450+ variadic = False ,
451+ default = param .default
452+ if param .default is not inspect .Parameter .empty
453+ else _EMPTY_DEFAULT ,
454+ )
455+ )
449456 else :
450457 type_ = type_hints [param .name ]
451458 if (attr_type := _get_attr_type (type_ )) != ir .AttributeType .UNDEFINED :
@@ -485,7 +492,7 @@ def from_function(
485492 type_constraints [type_constraint_name ] = type_constraint
486493 # 4. Create Parameter
487494 params .append (
488- Parameter ( # type: ignore[arg-type]
495+ Parameter (
489496 name = param .name ,
490497 type_constraint = type_constraint ,
491498 required = param .default is inspect .Parameter .empty ,
0 commit comments