
Commit 35beb2a

chore: linters
1 parent d4671a0 commit 35beb2a

4 files changed, +41 -58 lines changed
Lines changed: 26 additions & 26 deletions
@@ -1,12 +1,10 @@
+# pylint: disable=missing-docstring
 __all__ = [
-    "OnnxArgExtremumOld",
-    "OnnxArgExtremum",
+    'OnnxArgExtremumOld',
+    'OnnxArgExtremum',
 ]

-from typing import Optional
-
 import torch
-import torch.nn.functional as F
 from torch import nn

 from onnx2torch.node_converters.registry import add_converter
@@ -21,31 +19,31 @@
 DEFAULT_SELECT_LAST_INDEX = 0

 _TORCH_FUNCTION_FROM_ONNX_TYPE = {
-    "ArgMax": torch.argmax,
-    "ArgMin": torch.argmin,
+    'ArgMax': torch.argmax,
+    'ArgMin': torch.argmin,
 }


-class OnnxArgExtremumOld(nn.Module, OnnxToTorchModule):  # pylint: disable=missing-docstring
+class OnnxArgExtremumOld(nn.Module, OnnxToTorchModule):
     def __init__(self, operation_type: str, axis: int, keepdims: int):
         super().__init__()
         self.axis = axis
         self.keepdims = bool(keepdims)
         self.extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type]

-    def forward(self, data: torch.Tensor) -> torch.Tensor:  # pylint: disable=missing-function-docstring
+    def forward(self, data: torch.Tensor) -> torch.Tensor:
         return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims)


-class OnnxArgExtremum(nn.Module, OnnxToTorchModule):  # pylint: disable=missing-class-docstring
+class OnnxArgExtremum(nn.Module, OnnxToTorchModule):
     def __init__(self, operation_type: str, axis: int, keepdims: int, select_last_index: int):
         super().__init__()
         self.axis = axis
         self.keepdims = bool(keepdims)
         self.select_last_index = bool(select_last_index)
         self.extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type]

-    def forward(self, data: torch.Tensor) -> torch.Tensor:  # pylint: disable=missing-function-docstring
+    def forward(self, data: torch.Tensor) -> torch.Tensor:
         if self.select_last_index:
             # torch's argmax does not handle the select_last_index attribute from Onnx.
             # We flip the data, call the normal argmax, then map it back to the original
@@ -54,34 +52,36 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:  # pylint: disable=missing-function-docstring
             extremum_index_flipped = self.extremum_function(flipped, dim=self.axis, keepdim=self.keepdims)
             extremum_index_original = data.size(dim=self.axis) - 1 - extremum_index_flipped
             return extremum_index_original
-        else:
-            return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims)
+
+        return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims)


-@add_converter(operation_type="ArgMax", version=12)
-@add_converter(operation_type="ArgMax", version=13)
-@add_converter(operation_type="ArgMin", version=12)
-@add_converter(operation_type="ArgMin", version=13)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
+@add_converter(operation_type='ArgMax', version=12)
+@add_converter(operation_type='ArgMax', version=13)
+@add_converter(operation_type='ArgMin', version=12)
+@add_converter(operation_type='ArgMin', version=13)
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+    del graph
     return OperationConverterResult(
         torch_module=OnnxArgExtremum(
             operation_type=node.operation_type,
-            axis=node.attributes.get("axis", DEFAULT_AXIS),
-            keepdims=node.attributes.get("keepdims", DEFAULT_KEEPDIMS),
-            select_last_index=node.attributes.get("select_last_index", DEFAULT_SELECT_LAST_INDEX),
+            axis=node.attributes.get('axis', DEFAULT_AXIS),
+            keepdims=node.attributes.get('keepdims', DEFAULT_KEEPDIMS),
+            select_last_index=node.attributes.get('select_last_index', DEFAULT_SELECT_LAST_INDEX),
         ),
         onnx_mapping=onnx_mapping_from_node(node=node),
     )


-@add_converter(operation_type="ArgMax", version=11)
-@add_converter(operation_type="ArgMin", version=11)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
+@add_converter(operation_type='ArgMax', version=11)
+@add_converter(operation_type='ArgMin', version=11)
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+    del graph
     return OperationConverterResult(
         torch_module=OnnxArgExtremumOld(
             operation_type=node.operation_type,
-            axis=node.attributes.get("axis", DEFAULT_AXIS),
-            keepdims=node.attributes.get("keepdims", DEFAULT_KEEPDIMS),
+            axis=node.attributes.get('axis', DEFAULT_AXIS),
+            keepdims=node.attributes.get('keepdims', DEFAULT_KEEPDIMS),
         ),
         onnx_mapping=onnx_mapping_from_node(node=node),
     )
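A note on the `select_last_index` branch above: `torch.argmax`/`torch.argmin` always return the first occurrence of the extremum, while ONNX's `select_last_index=1` asks for the last one. A minimal standalone sketch of the flip-and-mirror trick (illustrative values, not part of the commit):

```python
import torch

data = torch.tensor([1, 3, 2, 3, 0])
axis = 0

# torch.argmax reports the FIRST index of the maximum value ...
first = torch.argmax(data, dim=axis)  # tensor(1)

# ... so to get the LAST index, flip along the axis, take argmax there,
# and mirror the flipped index back into original coordinates.
flipped = torch.flip(data, dims=(axis,))
last = data.size(dim=axis) - 1 - torch.argmax(flipped, dim=axis)  # tensor(3)

assert first.item() == 1 and last.item() == 3
```

The `del graph` statements serve the same goal as the dropped suppression comments: explicitly discarding the unused parameter keeps pylint's unused-argument check enabled instead of silencing it per function.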

onnx2torch/utils/custom_export_to_onnx.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ def export(cls, forward_function: Callable, *args) -> Any:
         return cls.apply(*args)

     @staticmethod
-    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument
+    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument, arguments-differ
         """Applies custom forward function."""
         if CustomExportToOnnx._NEXT_FORWARD_FUNCTION is None:
             raise RuntimeError('Forward function is not set')
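For context on the added `arguments-differ` suppression: pylint's W0221 check compares an override's signature against the base class, and the generic `(ctx, *args, **kwargs)` override of `torch.autograd.Function.forward` evidently trips it here. A minimal, repo-independent example of the `torch.autograd.Function` pattern this class builds on (the `Square` function is invented for illustration):

```python
import torch

class Square(torch.autograd.Function):
    """Toy autograd function: static forward/backward, invoked via .apply()."""

    @staticmethod
    def forward(ctx, x):  # pylint: disable=arguments-differ
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.tensor(3.0, requires_grad=True)
Square.apply(x).backward()
assert x.grad.item() == 6.0  # d(x^2)/dx at x = 3
```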

tests/node_converters/arg_extrema_test.py

Lines changed: 11 additions & 29 deletions
@@ -1,10 +1,11 @@
+# pylint: disable=missing-docstring
 from pathlib import Path

 import numpy as np
 import onnx
-from onnx.helper import make_tensor_value_info
 import pytest
 import torch
+from onnx.helper import make_tensor_value_info

 from tests.utils.common import check_onnx_model
 from tests.utils.common import make_model_from_nodes
@@ -51,7 +52,7 @@
     "select_last_index",
     (0, 1),
 )
-def test_arg_max_arg_min(  # pylint: disable=missing-function-docstring
+def test_arg_max_arg_min(
     op_type: str,
     opset_version: int,
     dims: int,
@@ -95,7 +96,7 @@ class ArgMaxModel(torch.nn.Module):
     def __init__(self, axis: int, keepdims: bool):
         super().__init__()
         self.axis = axis
-        self.keepdims = bool(keepdims)
+        self.keepdims = keepdims

     def forward(self, data: torch.Tensor) -> torch.Tensor:
         return torch.argmax(data, dim=self.axis, keepdim=self.keepdims)
@@ -105,29 +106,16 @@ class ArgMinModel(torch.nn.Module):
     def __init__(self, axis: int, keepdims: bool):
         super().__init__()
         self.axis = axis
-        self.keepdims = bool(keepdims)
+        self.keepdims = keepdims

     def forward(self, data: torch.Tensor) -> torch.Tensor:
         return torch.argmin(data, dim=self.axis, keepdim=self.keepdims)


+@pytest.mark.parametrize("op_type", ["ArgMax", "ArgMin"])
+@pytest.mark.parametrize("opset_version", [11, 12, 13])
 @pytest.mark.parametrize(
-    "op_type",
-    (
-        "ArgMax",
-        "ArgMin",
-    ),
-)
-@pytest.mark.parametrize(
-    "opset_version",
-    (
-        11,
-        12,
-        13,
-    ),
-)
-@pytest.mark.parametrize(
-    "dims,axis",
+    "dims, axis",
     (
         (1, 0),
         (2, 0),
@@ -141,19 +129,13 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
         (4, 3),
     ),
 )
-@pytest.mark.parametrize(
-    "keepdims",
-    (
-        0,
-        1,
-    ),
-)
+@pytest.mark.parametrize("keepdims", [True, False])
 def test_start_from_torch_module(
     op_type: str,
     opset_version: int,
     dims: int,
     axis: int,
-    keepdims: int,
+    keepdims: bool,
     tmp_path: Path,
 ) -> None:
     """
@@ -179,7 +161,7 @@ def test_start_from_torch_module(
         input_names=input_names,
         output_names=output_names,
         do_constant_folding=False,
-        training=torch._C._onnx.TrainingMode.TRAINING,
+        opset_version=opset_version,
     )

     # load the exported onnx file
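The final hunk replaces a private training-mode flag with an explicit `opset_version=opset_version`, so the exported graph actually matches the opset the test parametrizes over. A minimal sketch of an export call with the opset pinned (model and file name are placeholders, not taken from the test):

```python
import torch

model = torch.nn.Linear(4, 2)
dummy_input = torch.randn(1, 4)

# Pinning opset_version makes torch.onnx.export emit operators for that
# exact ONNX opset rather than the installed torch's default opset.
torch.onnx.export(
    model,
    dummy_input,
    'model.onnx',
    input_names=['input'],
    output_names=['output'],
    do_constant_folding=False,
    opset_version=13,
)
```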

tests/node_converters/conv_test.py

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,6 @@
 from itertools import chain
 from itertools import product
+from typing import Literal
 from typing import Tuple

 import numpy as np
@@ -10,7 +11,7 @@


 def _test_conv(
-    op_type: str,
+    op_type: Literal['Conv', 'ConvTranspose'],
     in_channels: int,
     out_channels: int,
     kernel_shape: Tuple[int, int],
@@ -23,7 +24,7 @@ def _test_conv(
     x = np.random.uniform(low=-1.0, high=1.0, size=x_shape).astype(np.float32)
     if op_type == 'Conv':
         weights_shape = (out_channels, in_channels // group) + kernel_shape
-    elif op_type == 'ConvTranspose':
+    else:  # ConvTranspose
         weights_shape = (in_channels, out_channels // group) + kernel_shape
     weights = np.random.uniform(low=-1.0, high=1.0, size=weights_shape).astype(np.float32)

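Narrowing `op_type` to `Literal['Conv', 'ConvTranspose']` is what justifies collapsing the `elif` into a bare `else`: a static type checker now rejects any other value at the call site, so the second branch can only mean `ConvTranspose`. A hedged illustration of the same idea (the helper below is invented for the example):

```python
from typing import Literal

def weights_shape_rule(op_type: Literal['Conv', 'ConvTranspose']) -> str:
    # With only two admissible values, no elif is needed:
    # anything that is not 'Conv' must be 'ConvTranspose'.
    if op_type == 'Conv':
        return '(out_channels, in_channels // group, *kernel_shape)'
    return '(in_channels, out_channels // group, *kernel_shape)'

weights_shape_rule('Conv')             # OK
# weights_shape_rule('DepthwiseConv')  # rejected by mypy/pyright, runs at runtime
```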

0 commit comments