 # ============================================================================
 r"""Weight Normalization from https://arxiv.org/abs/1602.07868."""
 from typing import Any, TypeVar
-from ..parameter import Parameter
+from typing_extensions import deprecated
+from ..parameter import Parameter, UninitializedParameter
 from ..modules import Module
 from ... import ops
 
@@ -43,12 +44,6 @@ def _weight_norm(weight_v, weight_g, dim):
 
 
 class WeightNorm:
-
-    r"""
-    The 'WeightNorm' class implements weight normalization for neural network modules. It provides methods to compute normalized weights, apply weight normalization to a cell, wrap a function, and remove
-    weight bias from a cell. The class also contains an initializer for the name and dimension of the weight parameters, as well as a method to compute the weight using the normalized parameters. Additionally, it
-    includes a method to remove the weight bias and a wrapper function for transposing cell_id to cell.
-    """
     name: str
     dim: int
 
@@ -60,63 +55,64 @@ def __init__(self, name: str, dim: int) -> None:
 
     # TODO Make return type more specific
     def compute_weight(self, module: Module) -> Any:
-        g = getattr(module, self.name + '_g')
-        v = getattr(module, self.name + '_v')
-        return Parameter(_weight_norm(v, g, self.dim))
+        g = getattr(module, self.name + "_g")
+        v = getattr(module, self.name + "_v")
+        return _weight_norm(v, g, self.dim)
 
     @staticmethod
-    def apply(module, name: str, dim: int) -> 'WeightNorm':
-        for k, hook in module._forward_pre_hooks.items():
+    @deprecated(
+        "`torch.nn.utils.weight_norm` is deprecated "
+        "in favor of `torch.nn.utils.parametrizations.weight_norm`.",
+        category=FutureWarning,
+    )
+    def apply(module, name: str, dim: int) -> "WeightNorm":
+        for hook in module._forward_pre_hooks.values():
             if isinstance(hook, WeightNorm) and hook.name == name:
-                raise RuntimeError("Cannot register two weight_norm hooks on "
-                                   "the same parameter {}".format(name))
+                raise RuntimeError(
+                    f"Cannot register two weight_norm hooks on the same parameter {name}"
+                )
 
         if dim is None:
             dim = -1
 
         fn = WeightNorm(name, dim)
 
         weight = getattr(module, name)
-        # if isinstance(weight, UninitializedParameter):
-        #     raise ValueError(
-        #         'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
-        #         'Make sure to run the dummy forward before applying weight normalization')
+        if isinstance(weight, UninitializedParameter):
+            raise ValueError(
+                "The module passed to `WeightNorm` can't have uninitialized parameters. "
+                "Make sure to run the dummy forward before applying weight normalization"
+            )
         # remove w from parameter list
         del module._parameters[name]
 
         # add g and v as new parameters and express w as g/||v|| * v
-        module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim)))
-        module.register_parameter(name + '_v', Parameter(weight))
+        module.register_parameter(
+            name + "_g", Parameter(norm_except_dim(weight, 2, dim).data)
+        )
+        module.register_parameter(name + "_v", Parameter(weight.data))
         setattr(module, name, fn.compute_weight(module))
 
         # recompute weight before every forward()
         module.register_forward_pre_hook(fn)
 
         return fn
 
-    def wrapper_func(self, cell, func):
-        r"""
-        wrapper_func where used to transpose cell_id to cell
-        """
-        def new_func(_, inputs):
-            nonlocal cell
-            return func(cell, inputs)
-        return new_func
-
     def remove(self, module: Module) -> None:
         weight = self.compute_weight(module)
         delattr(module, self.name)
-        del module._parameters[self.name + '_g']
-        del module._parameters[self.name + '_v']
-        setattr(module, self.name, weight)
+        del module._parameters[self.name + "_g"]
+        del module._parameters[self.name + "_v"]
+        setattr(module, self.name, Parameter(weight.data))
 
     def __call__(self, module: Module, inputs: Any) -> None:
         setattr(module, self.name, self.compute_weight(module))
 
 
-T_module = TypeVar('T_module', bound=Module)
+T_module = TypeVar("T_module", bound=Module)
+
 
-def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:
+def weight_norm(module: T_module, name: str = "weight", dim: int = 0) -> T_module:
     r"""Apply weight normalization to a parameter in the given module.
 
     .. math::
@@ -138,7 +134,7 @@ def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_modul
 
     .. warning::
 
-        This function is deprecated. Use :func:`mindtorch.nn.utils.parametrizations.weight_norm`
+        This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm`
         which uses the modern parametrization API. The new ``weight_norm`` is compatible
         with ``state_dict`` generated from old ``weight_norm``.
 
@@ -150,11 +146,11 @@ def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_modul
           https://github.com/pytorch/pytorch/issues/102999
 
         * To remove the weight normalization reparametrization, use
-          :func:`mindtorch.nn.utils.parametrize.remove_parametrizations`.
+          :func:`torch.nn.utils.parametrize.remove_parametrizations`.
 
         * The weight is no longer recomputed once at module forward; instead, it will
           be recomputed on every access. To restore the old behavior, use
-          :func:`mindtorch.nn.utils.parametrize.cached` before invoking the module
+          :func:`torch.nn.utils.parametrize.cached` before invoking the module
           in question.
 
     Args:
@@ -171,16 +167,17 @@ def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_modul
         >>> m
         Linear(in_features=20, out_features=40, bias=True)
         >>> m.weight_g.size()
-        mindtorch.Size([40, 1])
+        torch.Size([40, 1])
         >>> m.weight_v.size()
-        mindtorch.Size([40, 20])
+        torch.Size([40, 20])
 
     """
     WeightNorm.apply(module, name, dim)
     return module
 
-def remove_weight_norm(module: T_module, name: str = 'weight') -> T_module:
-    r"""Removes the weight normalization reparameterization from a module.
+
+def remove_weight_norm(module: T_module, name: str = "weight") -> T_module:
+    r"""Remove the weight normalization reparameterization from a module.
 
     Args:
         module (Module): containing module
@@ -196,5 +193,4 @@ def remove_weight_norm(module: T_module, name: str = 'weight') -> T_module:
             del module._forward_pre_hooks[k]
             return module
 
-    raise ValueError("weight_norm of '{}' not found in {}"
-                     .format(name, module))
+    raise ValueError(f"weight_norm of '{name}' not found in {module}")
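
For context, a minimal usage sketch of the two APIs touched by this diff: the deprecated hook-based `weight_norm`/`remove_weight_norm` and the parametrization-based replacement named in the new `@deprecated` message. This assumes a torch-compatible namespace (`torch`, `torch.nn`, `torch.nn.utils`), which this repository mirrors; exact import paths here may differ, and the alias and tensor shapes below are illustrative only.

import torch
from torch import nn
from torch.nn.utils import parametrize, remove_weight_norm, weight_norm
from torch.nn.utils.parametrizations import weight_norm as weight_norm_parametrized

# Deprecated hook-based API: registers weight_g / weight_v and a forward pre-hook
# (this path now emits a FutureWarning via the @deprecated decorator added above).
m = weight_norm(nn.Linear(20, 40), name="weight", dim=0)
print(m.weight_g.shape)  # torch.Size([40, 1])  -- per-output-row magnitudes g
print(m.weight_v.shape)  # torch.Size([40, 20]) -- directions v
remove_weight_norm(m)    # folds g / ||v|| * v back into a plain m.weight

# Recommended replacement: the parametrization API from the deprecation message.
m2 = weight_norm_parametrized(nn.Linear(20, 40), name="weight", dim=0)
with parametrize.cached():  # recompute the weight once here instead of on every access
    y = m2(torch.randn(8, 20))
parametrize.remove_parametrizations(m2, "weight")

After migration, the magnitude and direction tensors live under `m2.parametrizations.weight` rather than as `weight_g`/`weight_v`; see the migration notes linked in the docstring above.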