1+ # pylint: disable=missing-docstring
12__all__ = [
2- " OnnxArgExtremumOld" ,
3- " OnnxArgExtremum" ,
3+ ' OnnxArgExtremumOld' ,
4+ ' OnnxArgExtremum' ,
45]
56
6- from typing import Optional
7-
87import torch
9- import torch .nn .functional as F
108from torch import nn
119
1210from onnx2torch .node_converters .registry import add_converter
2119DEFAULT_SELECT_LAST_INDEX = 0
2220
2321_TORCH_FUNCTION_FROM_ONNX_TYPE = {
24- " ArgMax" : torch .argmax ,
25- " ArgMin" : torch .argmin ,
22+ ' ArgMax' : torch .argmax ,
23+ ' ArgMin' : torch .argmin ,
2624}
2725
2826
29- class OnnxArgExtremumOld (nn .Module , OnnxToTorchModule ): # pylint: disable=missing-docstring
27+ class OnnxArgExtremumOld (nn .Module , OnnxToTorchModule ):
3028 def __init__ (self , operation_type : str , axis : int , keepdims : int ):
3129 super ().__init__ ()
3230 self .axis = axis
3331 self .keepdims = bool (keepdims )
3432 self .extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE [operation_type ]
3533
36- def forward (self , data : torch .Tensor ) -> torch .Tensor : # pylint: disable=missing-function-docstring
34+ def forward (self , data : torch .Tensor ) -> torch .Tensor :
3735 return self .extremum_function (data , dim = self .axis , keepdim = self .keepdims )
3836
3937
40- class OnnxArgExtremum (nn .Module , OnnxToTorchModule ): # pylint: disable=missing-class-docstring
38+ class OnnxArgExtremum (nn .Module , OnnxToTorchModule ):
4139 def __init__ (self , operation_type : str , axis : int , keepdims : int , select_last_index : int ):
4240 super ().__init__ ()
4341 self .axis = axis
4442 self .keepdims = bool (keepdims )
4543 self .select_last_index = bool (select_last_index )
4644 self .extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE [operation_type ]
4745
48- def forward (self , data : torch .Tensor ) -> torch .Tensor : # pylint: disable=missing-function-docstring
46+ def forward (self , data : torch .Tensor ) -> torch .Tensor :
4947 if self .select_last_index :
5048 # torch's argmax does not handle the select_last_index attribute from Onnx.
5149 # We flip the data, call the normal argmax, then map it back to the original
@@ -54,34 +52,36 @@ def forward(self, data: torch.Tensor) -> torch.Tensor: # pylint: disable=missin
5452 extremum_index_flipped = self .extremum_function (flipped , dim = self .axis , keepdim = self .keepdims )
5553 extremum_index_original = data .size (dim = self .axis ) - 1 - extremum_index_flipped
5654 return extremum_index_original
57- else :
58- return self .extremum_function (data , dim = self .axis , keepdim = self .keepdims )
55+
56+ return self .extremum_function (data , dim = self .axis , keepdim = self .keepdims )
5957
6058
61- @add_converter (operation_type = "ArgMax" , version = 12 )
62- @add_converter (operation_type = "ArgMax" , version = 13 )
63- @add_converter (operation_type = "ArgMin" , version = 12 )
64- @add_converter (operation_type = "ArgMin" , version = 13 )
65- def _ (node : OnnxNode , graph : OnnxGraph ) -> OperationConverterResult : # pylint: disable=unused-argument
59+ @add_converter (operation_type = 'ArgMax' , version = 12 )
60+ @add_converter (operation_type = 'ArgMax' , version = 13 )
61+ @add_converter (operation_type = 'ArgMin' , version = 12 )
62+ @add_converter (operation_type = 'ArgMin' , version = 13 )
63+ def _ (node : OnnxNode , graph : OnnxGraph ) -> OperationConverterResult :
64+ del graph
6665 return OperationConverterResult (
6766 torch_module = OnnxArgExtremum (
6867 operation_type = node .operation_type ,
69- axis = node .attributes .get (" axis" , DEFAULT_AXIS ),
70- keepdims = node .attributes .get (" keepdims" , DEFAULT_KEEPDIMS ),
71- select_last_index = node .attributes .get (" select_last_index" , DEFAULT_SELECT_LAST_INDEX ),
68+ axis = node .attributes .get (' axis' , DEFAULT_AXIS ),
69+ keepdims = node .attributes .get (' keepdims' , DEFAULT_KEEPDIMS ),
70+ select_last_index = node .attributes .get (' select_last_index' , DEFAULT_SELECT_LAST_INDEX ),
7271 ),
7372 onnx_mapping = onnx_mapping_from_node (node = node ),
7473 )
7574
7675
77- @add_converter (operation_type = "ArgMax" , version = 11 )
78- @add_converter (operation_type = "ArgMin" , version = 11 )
79- def _ (node : OnnxNode , graph : OnnxGraph ) -> OperationConverterResult : # pylint: disable=unused-argument
76+ @add_converter (operation_type = 'ArgMax' , version = 11 )
77+ @add_converter (operation_type = 'ArgMin' , version = 11 )
78+ def _ (node : OnnxNode , graph : OnnxGraph ) -> OperationConverterResult :
79+ del graph
8080 return OperationConverterResult (
8181 torch_module = OnnxArgExtremumOld (
8282 operation_type = node .operation_type ,
83- axis = node .attributes .get (" axis" , DEFAULT_AXIS ),
84- keepdims = node .attributes .get (" keepdims" , DEFAULT_KEEPDIMS ),
83+ axis = node .attributes .get (' axis' , DEFAULT_AXIS ),
84+ keepdims = node .attributes .get (' keepdims' , DEFAULT_KEEPDIMS ),
8585 ),
8686 onnx_mapping = onnx_mapping_from_node (node = node ),
8787 )
0 commit comments