Skip to content

Commit 2f0be15

Browse files
authored
Add check_flatness keyword to HZ and format (#182)
* Add check_flatness keyword to HZ and format * undo format * Update issues.jl * Update issues.jl * Fix tests
1 parent 3259cd2 commit 2f0be15

File tree

4 files changed

+152
-12
lines changed

4 files changed

+152
-12
lines changed

src/hagerzhang.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ Conjugate gradient line search implementation from:
9292
display::Int = 0
9393
mayterminate::Tm = Ref{Bool}(false)
9494
cache::Union{Nothing,LineSearchCache{T}} = nothing
95+
check_flatness::Bool = false
9596
end
9697
HagerZhang{T}(args...; kwargs...) where T = HagerZhang{T, Base.RefValue{Bool}}(args...; kwargs...)
9798

@@ -285,12 +286,13 @@ function (ls::HagerZhang)(ϕ, ϕdϕ,
285286
if display & LINESEARCH > 0
286287
println("Linesearch: secant succeeded")
287288
end
288-
if nextfloat(values[ia]) >= values[ib] && nextfloat(values[iA]) >= values[iB]
289+
if ls.check_flatness && nextfloat(values[ia]) >= values[ib] && nextfloat(values[iA]) >= values[iB]
289290
# It's so flat, secant didn't do anything useful, time to quit
290291
if display & LINESEARCH > 0
291292
println("Linesearch: secant suggests it's flat")
292293
end
293294
mayterminate[] = false # reset in case another initial guess is used next
295+
294296
return A, values[iA]
295297
end
296298
ia = iA

test/captured.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@ end
3131
fdf = OnceDifferentiable(tc)
3232
hz = HagerZhang()
3333
α, val = hz(fdf.f, fdf.fdf, 1.0, fdf.fdf(0.0)...)
34-
@test_broken val <= minimum(tc)
34+
@test val <= minimum(tc)
3535
end

test/issues.jl

Lines changed: 146 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,36 +3,173 @@ using LineSearches, LinearAlgebra, Test
33

44
A = randn(100, 100)
55
x0 = randn(100)
6-
b = A*x0
6+
b = A * x0
77

88
# Objective function and gradient
9-
f(x) = .5*norm(A*x - b)^2
10-
g!(gvec, x) = (gvec .= A'*(A*x-b))
9+
f(x) = 0.5 * norm(A * x - b)^2
10+
g!(gvec, x) = (gvec .= A' * (A * x - b))
1111
fg!(gvec, x) = (g!(gvec, x); return f(x))
1212

1313
# Init
14-
x = 1f1*randn(100)
14+
x = 1.0f1 * randn(100)
1515
gv = similar(x)
1616

1717
# Line search
18-
α0 = 1f-3
18+
α0 = 1.0f-3
1919
ϕ0 = fg!(gv, x)
20-
s = -1*gv
20+
s = -1 * gv
2121
dϕ0 = dot(gv, s)
2222
println(ϕ0, ", ", dϕ0)
2323

2424
# Univariate line search functions
25-
ϕ(α) = f(x .+ α.*s)
25+
ϕ(α) = f(x .+ α .* s)
2626
function dϕ(α)
27-
g!(gv, x .+ α.*s)
27+
g!(gv, x .+ α .* s)
2828
return dot(gv, s)
2929
end
3030
function ϕdϕ(α)
31-
phi = fg!(gv, x .+ α.*s)
31+
phi = fg!(gv, x .+ α .* s)
3232
dphi = dot(gv, s)
3333
return (phi, dphi)
3434
end
3535

3636
res = (StrongWolfe())(ϕ, dϕ, ϕdϕ, α0, ϕ0, dϕ0)
3737
@test res[2] > 0
3838
@test res[2] == ϕ(res[1])
39+
40+
@testset "HZ convergence issues" begin
41+
@testset "Flatness check issues" begin
42+
function prepare_test_case(; alphas, values, slopes)
43+
perm = sortperm(alphas)
44+
alphas = alphas[perm]
45+
push!(alphas, alphas[end] + 1)
46+
values = values[perm]
47+
push!(values, values[end])
48+
slopes = slopes[perm]
49+
push!(slopes, 0.0)
50+
return LineSearchTestCase(alphas, values, slopes)
51+
end
52+
53+
tc1 = prepare_test_case(;
54+
alphas = [0.0, 1.0, 5.0, 3.541670844449739],
55+
values = [
56+
3003.592409634743,
57+
2962.0378569864743,
58+
2891.4462095232184,
59+
3000.9760725116876,
60+
],
61+
slopes = [
62+
-22332.321416890798,
63+
-20423.214551925797,
64+
11718.185026267562,
65+
-22286.821227217057,
66+
],
67+
)
68+
69+
function tc_to_f(tc)
70+
function f(x)
71+
i = findfirst(u -> u > x, tc.alphas) - 1
72+
xk = tc.alphas[i]
73+
xkp1 = tc.alphas[i+1]
74+
dx = xkp1 - xk
75+
t = (x - xk) / dx
76+
h00t = 2t^3 - 3t^2 + 1
77+
h10t = t * (1 - t)^2
78+
h01t = t^2 * (3 - 2t)
79+
h11t = t^2 * (t - 1)
80+
val =
81+
h00t * tc.values[i] +
82+
h10t * dx * tc.slopes[i] +
83+
h01t * tc.values[i+1] +
84+
h11t * dx * tc.slopes[i+1]
85+
86+
return val
87+
end
88+
end
89+
function tc_to_fdf(tc)
90+
function fdf(x)
91+
i = findfirst(u -> u > x, tc.alphas) - 1
92+
xk = tc.alphas[i]
93+
xkp1 = tc.alphas[i+1]
94+
dx = xkp1 - xk
95+
t = (x - xk) / dx
96+
h00t = 2t^3 - 3t^2 + 1
97+
h10t = t * (1 - t)^2
98+
h01t = t^2 * (3 - 2t)
99+
h11t = t^2 * (t - 1)
100+
val =
101+
h00t * tc.values[i] +
102+
h10t * dx * tc.slopes[i] +
103+
h01t * tc.values[i+1] +
104+
h11t * dx * tc.slopes[i+1]
105+
106+
h00tp = 6t^2 - 6t
107+
h10tp = 3t^2 - 4t + 1
108+
h01tp = -6t^2 + 6 * t
109+
h11tp = 3t^2 - 2t
110+
slope =
111+
(
112+
h00tp * tc.values[i] +
113+
h10tp * dx * tc.slopes[i] +
114+
h01tp * tc.values[i+1] +
115+
h11tp * dx * tc.slopes[i+1]
116+
) / dx
117+
println(x, " ", val, " ", slope)
118+
return val, slope
119+
end
120+
end
121+
122+
function test_tc(tc, check_flatness)
123+
cache = LineSearchCache{Float64}()
124+
hz = HagerZhang(; cache, check_flatness)
125+
f = tc_to_f(tc)
126+
fdf = tc_to_fdf(tc)
127+
hz(f, fdf, 1.0, fdf(0.0)...), cache
128+
end
129+
130+
res, res_cache = test_tc(tc1, true)
131+
@show res
132+
@show res_cache
133+
@test_broken minimum(res_cache.values) == res[2]
134+
135+
res2, res_cache2 = test_tc(tc1, false)
136+
@test minimum(res_cache2.values) == res2[2]
137+
#=
138+
using AlgebraOfGraphics, CairoMakie
139+
draw(data((x=0.0:0.05:5.5, y=map(x->tc_to_f(tc1)(x), 0:0.05:5.5)))*mapping(:x,:y)*visual(Scatter)+
140+
data((alphas=res_cache.alphas, values=res_cache.values))*mapping(:alphas,:values)*visual(Scatter; color=:red))
141+
=#
142+
end
143+
144+
# should add as upstream
145+
#=
146+
@testset "from kbarros" begin
147+
# The minimizer is x0=[0, 2πn/100], with f(x0) = 1. Any integer n is fine.
148+
function f(x)
149+
return (x[1]^2 + 1) * (2 - cos(100*x[2]))
150+
end
151+
152+
using Optim
153+
154+
function test_converges(method)
155+
for i in 1:100
156+
r = randn(2)
157+
res = optimize(f, r, method)
158+
if Optim.converged(res) && minimum(res) > f([0,0]) + 1e-8
159+
println("""
160+
Incorrectly reported convergence after $(res.iterations) iterations
161+
Reached x = $(Optim.minimizer(res)) with f(x) = $(minimum(res))
162+
""")
163+
end
164+
end
165+
end
166+
167+
# Works successfully, no printed output
168+
test_converges(LBFGS(; linesearch=Optim.LineSearches.BackTracking(order=2)))
169+
170+
# Prints ~10 failures to converge (in 100 tries). Frequently fails after the
171+
# first line search.
172+
test_converges(ConjugateGradient(; linesearch=Optim.LineSearches.HagerZhang(check_flatness=false)))
173+
end
174+
=#
175+
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ my_tests = [
1515
"alphacalc.jl",
1616
"arbitrary_precision.jl",
1717
"examples.jl",
18-
"captured.jl"
18+
"captured.jl",
19+
"issues.jl",
1920
]
2021

2122
mutable struct StateDummy

0 commit comments

Comments
 (0)