Skip to content

Commit c10f5f5

Browse files
committed
update with the latest pytorch change
1 parent b88ab17 commit c10f5f5

File tree

1 file changed

+24
-17
lines changed

1 file changed

+24
-17
lines changed

onnxscript/ir/_schemas.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,7 @@
88
import logging
99
import types
1010
import typing
11-
from typing import (
12-
Any,
13-
Iterator,
14-
Mapping,
15-
Optional,
16-
Sequence,
17-
TypeVar,
18-
Union,
19-
)
11+
from typing import Any, Iterator, Mapping, Optional, Sequence, TypeVar, Union
2012

2113
import onnx
2214

@@ -103,7 +95,7 @@ def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam:
10395

10496
@classmethod
10597
def any_value(cls, name: str, description: str = "") -> TypeConstraintParam:
106-
return cls(name, _ALL_VALUE_TYPES, description) # type: ignore
98+
return cls(name, _ALL_VALUE_TYPES, description) # type: ignore[arg-type]
10799

108100

109101
@dataclasses.dataclass(frozen=True)
@@ -129,6 +121,8 @@ def has_default(self) -> bool:
129121

130122
@dataclasses.dataclass(frozen=True)
131123
class AttributeParameter:
124+
"""A parameter in the function signature that represents an ONNX attribute."""
125+
132126
name: str
133127
type: ir.AttributeType
134128
required: bool
@@ -147,7 +141,7 @@ def has_default(self) -> bool:
147141
def _get_type_from_str(
148142
type_str: str,
149143
) -> ir.TensorType | ir.SequenceType | ir.OptionalType:
150-
"""Converter a type_str from ONNX Opschema to ir.TypeProtocol.
144+
"""Converter a type_str from ONNX OpSchema to ir.TypeProtocol.
151145
152146
A type str has the form of "tensor(float)" or composite type like "seq(tensor(float))".
153147
"""
@@ -180,14 +174,14 @@ def _convert_formal_parameter(
180174
param: onnx.defs.OpSchema.FormalParameter,
181175
type_constraints: Mapping[str, TypeConstraintParam],
182176
) -> Parameter:
183-
"""Convert a formal parameter from ONNX Opschema to Parameter."""
177+
"""Convert a formal parameter from ONNX OpSchema to Parameter."""
184178
if param.type_str in type_constraints:
185179
type_constraint = type_constraints[param.type_str]
186180
else:
187181
# param.type_str can be a plain type like 'int64'.
188182
type_constraint = TypeConstraintParam(
189183
name=param.name,
190-
allowed_types={_get_type_from_str(param.type_str)}, # type: ignore
184+
allowed_types={_get_type_from_str(param.type_str)},
191185
)
192186
return Parameter(
193187
name=param.name,
@@ -377,7 +371,7 @@ def __str__(self) -> str:
377371

378372
@classmethod
379373
def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature:
380-
"""Produce an OpSignature from an ONNX Opschema."""
374+
"""Produce an OpSignature from an ONNX OpSchema."""
381375
type_constraints = {
382376
constraint.type_param_str: TypeConstraintParam(
383377
name=constraint.type_param_str,
@@ -434,7 +428,7 @@ def from_function(
434428
# https://github.com/python/cpython/issues/102405
435429
type_hints = typing.get_type_hints(func)
436430

437-
params = []
431+
params: list[Parameter | AttributeParameter] = []
438432
# Create a mapping from type to a unique name
439433
type_constraints: dict[str, TypeConstraintParam] = {}
440434

@@ -445,7 +439,20 @@ def from_function(
445439
param.name,
446440
py_signature,
447441
)
448-
type_constraints[param.name] = TypeConstraintParam.any_value(f"T_{param.name}")
442+
type_constraint = TypeConstraintParam.any_value(f"T_{param.name}")
443+
type_constraints[param.name] = type_constraint
444+
params.append(
445+
Parameter(
446+
name=param.name,
447+
type_constraint=type_constraint,
448+
required=param.default is inspect.Parameter.empty,
449+
# TODO: Handle variadic
450+
variadic=False,
451+
default=param.default
452+
if param.default is not inspect.Parameter.empty
453+
else _EMPTY_DEFAULT,
454+
)
455+
)
449456
else:
450457
type_ = type_hints[param.name]
451458
if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED:
@@ -485,7 +492,7 @@ def from_function(
485492
type_constraints[type_constraint_name] = type_constraint
486493
# 4. Create Parameter
487494
params.append(
488-
Parameter( # type: ignore[arg-type]
495+
Parameter(
489496
name=param.name,
490497
type_constraint=type_constraint,
491498
required=param.default is inspect.Parameter.empty,

0 commit comments

Comments
 (0)