
Commit 9c72aa0

Add adabelief optimizer (#209)
1 parent 0d94e4e commit 9c72aa0

File tree

5 files changed (+255 -14 lines)

tests/test_basic.py
tests/test_optimizer.py
tests/test_optimizer_with_nn.py
torch_optimizer/__init__.py
torch_optimizer/adabelief.py

tests/test_basic.py

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ def build_lookahead(*a, **kw):
     (optim.AggMo, {'lr': 0.003}, 1800),
     (optim.SWATS, {'lr': 0.1, 'amsgrad': True, 'nesterov': True}, 900),
     (optim.Adafactor, {'lr': None, 'decay_rate': -0.3, 'beta1': 0.9}, 800),
+    (optim.AdaBelief, {'lr': 1.0}, 500),
 ]
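Each entry added here is an (optimizer_class, config, iterations) tuple for the parametrized basic tests. The objective functions those tests minimize are not part of this diff, so the snippet below is only a hedged sketch of how such a tuple can be exercised against an assumed convex stand-in objective; it is not the repository's actual test code.

# Hypothetical harness sketch: the quadratic objective and the printed
# loss comparison are assumptions standing in for the real test functions,
# which live elsewhere in tests/test_basic.py and are not shown in this diff.
import torch
import torch_optimizer as optim


def exercise(optimizer_cls, config, iterations):
    x = torch.tensor([1.5, 1.5], requires_grad=True)
    optimizer = optimizer_cls([x], **config)
    initial = ((x - 1.0) ** 2).sum().item()
    for _ in range(iterations):
        optimizer.zero_grad()
        loss = ((x - 1.0) ** 2).sum()
        loss.backward()
        optimizer.step()
    print(initial, loss.item())  # the real tests assert convergence instead


# mirrors the tuple added in the hunk above
exercise(optim.AdaBelief, {'lr': 1.0}, 500)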

tests/test_optimizer.py

Lines changed: 2 additions & 1 deletion
@@ -65,11 +65,12 @@ def build_lookahead(*a, **kw):
 optimizers = [
+    build_lookahead,
     optim.A2GradExp,
     optim.A2GradInc,
     optim.A2GradUni,
-    build_lookahead,
     optim.AccSGD,
+    optim.AdaBelief,
     optim.AdaBound,
     optim.AdaMod,
     optim.AdamP,

tests/test_optimizer_with_nn.py

Lines changed: 14 additions & 13 deletions
@@ -51,31 +51,32 @@ def build_lookahead(*a, **kw):
 optimizers = [
+    (build_lookahead, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
     (optim.A2GradExp, {'lips': 1.0, 'beta': 1e-3}, 200),
     (optim.A2GradInc, {'lips': 1.0, 'beta': 1e-3}, 200),
     (optim.A2GradUni, {'lips': 1.0, 'beta': 1e-3}, 200),
+    (optim.AccSGD, {'lr': 1.0, 'weight_decay': 1e-3}, 200),
+    (optim.AdaBelief, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
+    (optim.AdaBound, {'lr': 1.5, 'gamma': 0.1, 'weight_decay': 1e-3}, 200),
+    (optim.AdaMod, {'lr': 2.0, 'weight_decay': 1e-3}, 200),
     (optim.Adafactor, {'lr': None, 'weight_decay': 1e-3}, 200),
+    (optim.AdamP, {'lr': 1.0, 'weight_decay': 1e-3}, 200),
+    (optim.AggMo, {'lr': 1.0, 'weight_decay': 1e-3}, 200),
+    (optim.DiffGrad, {'lr': 0.5, 'weight_decay': 1e-3}, 200),
+    (optim.Lamb, {'lr': 0.01, 'weight_decay': 1e-3}, 200),
     (optim.NovoGrad, {'lr': 0.01, 'weight_decay': 1e-3}, 200),
     (optim.PID, {'lr': 0.01, 'weight_decay': 1e-3, 'momentum': 0.1}, 200),
-    (optim.Lamb, {'lr': 0.01, 'weight_decay': 1e-3}, 200),
-    (optim.SGDW, {'lr': 1.0, 'weight_decay': 1e-3}, 200),
-    (optim.DiffGrad, {'lr': 0.5, 'weight_decay': 1e-3}, 200),
-    (optim.AdaMod, {'lr': 2.0, 'weight_decay': 1e-3}, 200),
-    (optim.AdaBound, {'lr': 1.5, 'gamma': 0.1, 'weight_decay': 1e-3}, 200),
-    (optim.Yogi, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
-    (optim.RAdam, {'lr': 1.0, 'weight_decay': 1e-3}, 200),
-    (optim.AccSGD, {'lr': 1.0, 'weight_decay': 1e-3}, 200),
-    (build_lookahead, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
-    (optim.QHM, {'lr': 0.1, 'weight_decay': 1e-5, 'momentum': 0.2}, 200),
     (optim.QHAdam, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
+    (optim.QHM, {'lr': 0.1, 'weight_decay': 1e-5, 'momentum': 0.2}, 200),
+    (optim.RAdam, {'lr': 1.0, 'weight_decay': 1e-3}, 200),
     (optim.Ranger, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
     (optim.RangerQH, {'lr': 0.01, 'weight_decay': 1e-3}, 200),
     (optim.RangerVA, {'lr': 0.01, 'weight_decay': 1e-3}, 200),
-    (optim.Shampoo, {'lr': 0.1, 'weight_decay': 1e-3, 'momentum': 0.8}, 200),
-    (optim.AdamP, {'lr': 1.0, 'weight_decay': 1e-3}, 200),
     (optim.SGDP, {'lr': 1.0, 'weight_decay': 1e-3}, 200),
-    (optim.AggMo, {'lr': 1.0, 'weight_decay': 1e-3}, 200),
+    (optim.SGDW, {'lr': 1.0, 'weight_decay': 1e-3}, 200),
     (optim.SWATS, {'lr': 1.0, 'weight_decay': 1e-3}, 200),
+    (optim.Shampoo, {'lr': 0.1, 'weight_decay': 1e-3, 'momentum': 0.8}, 200),
+    (optim.Yogi, {'lr': 0.1, 'weight_decay': 1e-3}, 200),
 ]

torch_optimizer/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,7 @@
 from .a2grad import A2GradExp, A2GradInc, A2GradUni
 from .accsgd import AccSGD
+from .adabelief import AdaBelief
 from .adabound import AdaBound
 from .adafactor import Adafactor
 from .adamod import AdaMod
@@ -41,6 +42,7 @@
 from .yogi import Yogi

 __all__ = (
+    'AdaBelief',
     'A2GradExp',
     'A2GradInc',
     'A2GradUni',
@@ -73,6 +75,7 @@
 _package_opts = [
+    AdaBelief,
     AccSGD,
     AdaBound,
     AdaMod,
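With the new import, the __all__ entry, and the _package_opts registration, AdaBelief is reachable from the package root. Below is a minimal sketch of what that enables; the tensor is a placeholder, and the remark about name-based lookup is an assumption about how _package_opts is consumed elsewhere in __init__.py, which this diff does not show.

import torch
import torch_optimizer as optim

# Direct attribute access is enabled by the new import and __all__ entry.
weight = torch.randn(10, requires_grad=True)
optimizer = optim.AdaBelief([weight], lr=1e-3, betas=(0.9, 0.999), eps=1e-3)

weight.sum().backward()
optimizer.step()

# _package_opts presumably also feeds the package's name-based optimizer
# registry (assumption: that lookup helper lives outside the lines changed
# here), which is why the class is appended to that list as well.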

torch_optimizer/adabelief.py

Lines changed: 235 additions & 0 deletions
import math

import torch
from torch.optim.optimizer import Optimizer

from .types import Betas2, OptFloat, OptLossClosure, Params

version_higher = torch.__version__ >= '1.5.0'


__all__ = ('AdaBelief',)


class AdaBelief(Optimizer):
    r"""Implements AdaBelief Optimizer Algorithm.
    It has been proposed in `AdaBelief Optimizer, adapting stepsizes by
    the belief in observed gradients`__.

    Arguments:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: learning rate (default: 1e-3)
        betas: coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps: term added to the denominator to improve
            numerical stability (default: 1e-3)
        weight_decay: weight decay (L2 penalty) (default: 0)
        amsgrad: whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        weight_decouple: if set to True, the optimizer uses decoupled
            weight decay as in AdamW (default: False)
        fixed_decay: only used when weight_decouple is set to True.
            When fixed_decay == True, the weight decay is performed as
            $W_{new} = W_{old} - W_{old} \times decay$.
            When fixed_decay == False, the weight decay is performed as
            $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that
            in this case the effective weight decay scales with the
            learning rate (lr). (default: False)
        rectify: if set to True, performs the rectified update similar
            to RAdam (default: False)

    Example:
        >>> import torch_optimizer as optim
        >>> optimizer = optim.AdaBelief(model.parameters(), lr=0.01)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ https://arxiv.org/abs/2010.07468

    Note:
        Reference code: https://github.com/juntang-zhuang/Adabelief-Optimizer
    """

    def __init__(
        self,
        params: Params,
        lr: float = 1e-3,
        betas: Betas2 = (0.9, 0.999),
        eps: float = 1e-3,
        weight_decay: float = 0,
        amsgrad: bool = False,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        rectify: bool = False,
    ) -> None:
        if lr <= 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if eps < 0.0:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                'Invalid beta parameter at index 0: {}'.format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                'Invalid beta parameter at index 1: {}'.format(betas[1])
            )
        if weight_decay < 0:
            raise ValueError(
                'Invalid weight_decay value: {}'.format(weight_decay)
            )
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
        super(AdaBelief, self).__init__(params, defaults)

        self._weight_decouple = weight_decouple
        self._rectify = rectify
        self._fixed_decay = fixed_decay

    def __setstate__(self, state):
        super(AdaBelief, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure: OptLossClosure = None) -> OptFloat:
        r"""Performs a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdaBelief does not support sparse gradients, '
                        'please consider SparseAdam instead'
                    )
                amsgrad = group['amsgrad']

                state = self.state[p]

                beta1, beta2 = group['betas']

                # State initialization
                if len(state) == 0:
                    state['rho_inf'] = 2.0 / (1.0 - beta2) - 1.0
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = (
                        torch.zeros_like(
                            p.data, memory_format=torch.preserve_format
                        )
                        if version_higher
                        else torch.zeros_like(p.data)
                    )
                    # Exponential moving average of squared gradient values
                    state['exp_avg_var'] = (
                        torch.zeros_like(
                            p.data, memory_format=torch.preserve_format
                        )
                        if version_higher
                        else torch.zeros_like(p.data)
                    )
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of
                        # sq. grad. values
                        state['max_exp_avg_var'] = (
                            torch.zeros_like(
                                p.data, memory_format=torch.preserve_format
                            )
                            if version_higher
                            else torch.zeros_like(p.data)
                        )

                # get current state variable
                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # perform weight decay, check if decoupled weight decay
                if self._weight_decouple:
                    if not self._fixed_decay:
                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
                    else:
                        p.data.mul_(1.0 - group['weight_decay'])
                else:
                    if group['weight_decay'] != 0:
                        grad.add_(p.data, alpha=group['weight_decay'])

                # Update first and second moment running average
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                grad_residual = grad - exp_avg
                exp_avg_var.mul_(beta2).addcmul_(
                    grad_residual, grad_residual, value=1 - beta2
                )

                if amsgrad:
                    max_exp_avg_var = state['max_exp_avg_var']
                    # Maintains the maximum of all 2nd moment running
                    # avg. till now
                    torch.max(
                        max_exp_avg_var, exp_avg_var, out=max_exp_avg_var
                    )

                    # Use the max. for normalizing running avg. of gradient
                    denom = (
                        max_exp_avg_var.add_(group['eps']).sqrt()
                        / math.sqrt(bias_correction2)
                    ).add_(group['eps'])
                else:
                    denom = (
                        exp_avg_var.add_(group['eps']).sqrt()
                        / math.sqrt(bias_correction2)
                    ).add_(group['eps'])

                if not self._rectify:
                    # Default update
                    step_size = group['lr'] / bias_correction1
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)

                else:  # Rectified update
                    # calculate rho_t
                    state['rho_t'] = state['rho_inf'] - 2 * state[
                        'step'
                    ] * beta2 ** state['step'] / (1.0 - beta2 ** state['step'])

                    if (
                        state['rho_t'] > 4
                    ):  # perform Adam style update if variance is small
                        rho_inf, rho_t = state['rho_inf'], state['rho_t']
                        rt = (
                            (rho_t - 4.0)
                            * (rho_t - 2.0)
                            * rho_inf
                            / (rho_inf - 4.0)
                            / (rho_inf - 2.0)
                            / rho_t
                        )
                        rt = math.sqrt(rt)

                        step_size = rt * group['lr'] / bias_correction1

                        p.data.addcdiv_(exp_avg, denom, value=-step_size)

                    else:  # perform SGD style update
                        p.data.add_(exp_avg, alpha=-group['lr'])

        return loss
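The docstring example above leaves model, input, target and loss_fn undefined, so here is a slightly fuller, self-contained usage sketch. The toy regression data, the hyperparameters, and the choice to enable weight_decouple and rectify are illustrative assumptions, not something this commit prescribes.

import torch
import torch_optimizer as optim

torch.manual_seed(0)

# Toy regression problem: fit y = 2x + 1 with a single linear layer.
model = torch.nn.Linear(1, 1)
x = torch.linspace(-1.0, 1.0, 64).unsqueeze(1)
y = 2.0 * x + 1.0

# Decoupled (AdamW-style) weight decay plus the RAdam-style rectified update;
# both flags are introduced by this commit and default to False.
optimizer = optim.AdaBelief(
    model.parameters(),
    lr=1e-2,
    betas=(0.9, 0.999),
    eps=1e-3,
    weight_decay=1e-4,
    weight_decouple=True,
    rectify=True,
)
loss_fn = torch.nn.MSELoss()

for _ in range(500):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

# The fitted parameters should end up close to the true slope 2.0 and bias 1.0.
print(model.weight.item(), model.bias.item())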

0 commit comments
