Skip to content

Commit b34fa31

Browse files
authored
Refactor algorithm unit tests (#213)
* refactor better structure the algorithm tests * add missing file
1 parent 8491504 commit b34fa31

17 files changed

+806
-555
lines changed
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
2+
@testset "KLMinRepGradDescent" begin
3+
begin
4+
modelstats = normal_meanfield(Random.default_rng(), Float64)
5+
(; model, n_dims, μ_true, L_true) = modelstats
6+
7+
q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))
8+
9+
@testset "basic n_samples=$(n_samples)" for n_samples in [1, 10]
10+
alg = KLMinRepGradDescent(AD; n_samples, operator=ClipScale())
11+
T = 1
12+
optimize(alg, T, model, q0; show_progress=PROGRESS)
13+
end
14+
15+
@testset "callback" begin
16+
alg = KLMinRepGradDescent(AD; operator=ClipScale())
17+
T = 10
18+
callback(; iteration, kwargs...) = (iteration_check=iteration,)
19+
_, info, _ = optimize(alg, T, model, q0; callback, show_progress=PROGRESS)
20+
@test [i.iteration_check for i in info] == 1:T
21+
end
22+
23+
@testset "estimate_objective" begin
24+
alg = KLMinRepGradDescent(AD; operator=ClipScale())
25+
q_true = MeanFieldGaussian(Vector(μ_true), Diagonal(L_true))
26+
27+
obj_est = estimate_objective(alg, q_true, model)
28+
@test isfinite(obj_est)
29+
30+
obj_est = estimate_objective(alg, q_true, model; n_samples=1)
31+
@test isfinite(obj_est)
32+
33+
obj_est = estimate_objective(alg, q_true, model; n_samples=3)
34+
@test isfinite(obj_est)
35+
36+
obj_est = estimate_objective(alg, q_true, model; n_samples=10^5)
37+
@test obj_est 0 atol=1e-2
38+
end
39+
40+
@testset "determinism" begin
41+
alg = KLMinRepGradDescent(AD; operator=ClipScale())
42+
43+
seed = (0x38bef07cf9cc549d)
44+
rng = StableRNG(seed)
45+
T = 10
46+
47+
q_out, _, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS)
48+
μ = q_out.location
49+
L = q_out.scale
50+
51+
rng_repl = StableRNG(seed)
52+
q_out, _, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS)
53+
μ_repl = q_out.location
54+
L_repl = q_out.scale
55+
@test μ == μ_repl
56+
@test L == L_repl
57+
end
58+
59+
@testset "warn MvLocationScale with IdentityOperator" begin
60+
@test_warn "IdentityOperator" begin
61+
alg′ = KLMinRepGradDescent(AD; operator=IdentityOperator())
62+
optimize(alg′, 1, model, q0; show_progress=false)
63+
end
64+
end
65+
66+
@testset "STL variance reduction" begin
67+
@testset for n_montecarlo in [1, 10]
68+
q_true = MeanFieldGaussian(Vector(μ_true), Diagonal(L_true))
69+
params, re = Optimisers.destructure(q_true)
70+
obj = RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy())
71+
out = DiffResults.DiffResult(zero(eltype(params)), similar(params))
72+
73+
aux = (
74+
rng=Random.default_rng(),
75+
obj=obj,
76+
problem=model,
77+
restructure=re,
78+
q_stop=q_true,
79+
adtype=AD,
80+
)
81+
AdvancedVI._value_and_gradient!(
82+
AdvancedVI.estimate_repgradelbo_ad_forward, out, AD, params, aux
83+
)
84+
grad = DiffResults.gradient(out)
85+
@test norm(grad) 0 atol = 1e-5
86+
end
87+
end
88+
end
89+
90+
@testset "type stability realtype=$(realtype)" for realtype in [Float32, Float64]
91+
modelstats = normal_meanfield(Random.default_rng(), realtype)
92+
(; model, n_dims, μ_true, L_true) = modelstats
93+
94+
T = 1
95+
alg = KLMinRepGradDescent(AD; n_samples=10, operator=ClipScale())
96+
q0 = MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims)))
97+
98+
q_out, info, _ = optimize(alg, T, model, q0; show_progress=PROGRESS)
99+
100+
@test eltype(q_out.location) == realtype
101+
@test eltype(q_out.scale) == realtype
102+
@test typeof(first(info).elbo) == realtype
103+
end
104+
105+
@testset "convergence $(entropy)" for entropy in
106+
[ClosedFormEntropy(), StickingTheLandingEntropy()]
107+
modelstats = normal_meanfield(Random.default_rng(), Float64)
108+
(; model, μ_true, L_true, is_meanfield) = modelstats
109+
110+
T = 1000
111+
optimizer = Descent(1e-3)
112+
alg = KLMinRepGradDescent(AD; entropy, optimizer, operator=ClipScale())
113+
q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))
114+
115+
q_out, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS)
116+
117+
Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
118+
Δλ = sum(abs2, q_out.location - μ_true) + sum(abs2, q_out.scale - L_true)
119+
120+
@test Δλ Δλ0/2
121+
end
122+
123+
@testset "subsampling" begin
124+
n_data = 8
125+
126+
@testset "estimate_objective batchsize=$(batchsize)" for batchsize in [1, 3, 4]
127+
modelstats = subsamplednormal(Random.default_rng(), n_data)
128+
(; model, n_dims, μ_true, L_true) = modelstats
129+
130+
L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims))
131+
q0 = FullRankGaussian(zeros(Float64, n_dims), L0)
132+
operator = ClipScale()
133+
134+
subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize)
135+
alg = KLMinRepGradDescent(AD; n_samples=10, operator)
136+
alg_sub = KLMinRepGradDescent(AD; n_samples=10, subsampling, operator)
137+
138+
obj_full = estimate_objective(alg, q0, model; n_samples=10^5)
139+
obj_sub = estimate_objective(alg_sub, q0, model; n_samples=10^5)
140+
@test obj_full obj_sub rtol=0.1
141+
end
142+
143+
@testset "determinism" begin
144+
seed = (0x38bef07cf9cc549d)
145+
rng = StableRNG(seed)
146+
147+
modelstats = subsamplednormal(Random.default_rng(), n_data)
148+
(; model, n_dims, μ_true, L_true) = modelstats
149+
150+
L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims))
151+
q0 = FullRankGaussian(zeros(Float64, n_dims), L0)
152+
153+
T = 10
154+
batchsize = 3
155+
subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize)
156+
alg_sub = KLMinRepGradDescent(
157+
AD; n_samples=10, subsampling, operator=ClipScale()
158+
)
159+
160+
q, _, _ = optimize(rng, alg_sub, T, model, q0; show_progress=PROGRESS)
161+
μ = q.location
162+
L = q.scale
163+
164+
rng_repl = StableRNG(seed)
165+
q, _, _ = optimize(rng_repl, alg_sub, T, model, q0; show_progress=PROGRESS)
166+
μ_repl = q.location
167+
L_repl = q.scale
168+
@test μ == μ_repl
169+
@test L == L_repl
170+
end
171+
172+
@testset "convergence" begin
173+
modelstats = subsamplednormal(Random.default_rng(), n_data)
174+
(; model, n_dims, μ_true, L_true) = modelstats
175+
176+
L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims))
177+
q0 = FullRankGaussian(zeros(Float64, n_dims), L0)
178+
179+
T = 1000
180+
batchsize = 1
181+
optimizer = Descent(1e-3)
182+
subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize)
183+
alg_sub = KLMinRepGradDescent(
184+
AD; n_samples=10, subsampling, optimizer, operator=ClipScale()
185+
)
186+
187+
q, stats, _ = optimize(alg_sub, T, model, q0; show_progress=PROGRESS)
188+
189+
Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
190+
Δλ = sum(abs2, q.location - μ_true) + sum(abs2, q.scale - L_true)
191+
192+
@test Δλ Δλ0/2
193+
end
194+
end
195+
end
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
2+
@testset "KLMinRepGradDescent with Bijectors" begin
3+
begin
4+
modelstats = normallognormal_meanfield(Random.default_rng(), Float64)
5+
(; model, n_dims, μ_true, L_true) = modelstats
6+
7+
b = Bijectors.bijector(model)
8+
binv = inverse(b)
9+
10+
q0_unconstr = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))
11+
q0 = Bijectors.transformed(q0_unconstr, binv)
12+
13+
@testset "estimate_objective" begin
14+
alg = KLMinRepGradDescent(AD; operator=ClipScale())
15+
q_true_unconstr = MeanFieldGaussian(Vector(μ_true), Diagonal(L_true))
16+
q_true = Bijectors.transformed(q_true_unconstr, binv)
17+
18+
obj_est = estimate_objective(alg, q_true, model)
19+
@test isfinite(obj_est)
20+
21+
obj_est = estimate_objective(alg, q_true, model; n_samples=1)
22+
@test isfinite(obj_est)
23+
24+
obj_est = estimate_objective(alg, q_true, model; n_samples=3)
25+
@test isfinite(obj_est)
26+
27+
obj_est = estimate_objective(alg, q_true, model; n_samples=10^5)
28+
@test obj_est 0 atol=1e-2
29+
end
30+
31+
@testset "determinism" begin
32+
alg = KLMinRepGradDescent(AD; operator=ClipScale())
33+
34+
seed = (0x38bef07cf9cc549d)
35+
rng = StableRNG(seed)
36+
T = 10
37+
38+
q_out, _, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS)
39+
μ = q_out.dist.location
40+
L = q_out.dist.scale
41+
42+
rng_repl = StableRNG(seed)
43+
q_out, _, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS)
44+
μ_repl = q_out.dist.location
45+
L_repl = q_out.dist.scale
46+
@test μ == μ_repl
47+
@test L == L_repl
48+
end
49+
50+
@testset "warn MvLocationScale with IdentityOperator" begin
51+
@test_warn "IdentityOperator" begin
52+
alg′ = KLMinRepGradDescent(AD; operator=IdentityOperator())
53+
optimize(alg′, 1, model, q0; show_progress=false)
54+
end
55+
end
56+
end
57+
58+
@testset "type stability realtype=$(realtype)" for realtype in [Float32, Float64]
59+
modelstats = normallognormal_meanfield(Random.default_rng(), realtype)
60+
(; model, n_dims, μ_true, L_true) = modelstats
61+
62+
T = 1
63+
alg = KLMinRepGradDescent(AD; n_samples=10, operator=ClipScale())
64+
q0_unconstr = MeanFieldGaussian(
65+
zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))
66+
)
67+
q0 = Bijectors.transformed(q0_unconstr, binv)
68+
69+
q_out, info, _ = optimize(alg, T, model, q0; show_progress=PROGRESS)
70+
71+
@test eltype(q_out.dist.location) == realtype
72+
@test eltype(q_out.dist.scale) == realtype
73+
@test typeof(first(info).elbo) == realtype
74+
end
75+
76+
@testset "convergence $(entropy)" for entropy in
77+
[ClosedFormEntropy(), StickingTheLandingEntropy()]
78+
modelstats = normallognormal_meanfield(Random.default_rng(), Float64)
79+
(; model, μ_true, L_true, is_meanfield) = modelstats
80+
81+
T = 1000
82+
optimizer = Descent(1e-3)
83+
alg = KLMinRepGradDescent(AD; entropy, optimizer, operator=ClipScale())
84+
q0_unconstr = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))
85+
q0 = Bijectors.transformed(q0_unconstr, binv)
86+
87+
q_out, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS)
88+
89+
Δλ0 = sum(abs2, q0.dist.location - μ_true) + sum(abs2, q0.dist.scale - L_true)
90+
Δλ = sum(abs2, q_out.dist.location - μ_true) + sum(abs2, q_out.dist.scale - L_true)
91+
92+
@test Δλ Δλ0/2
93+
end
94+
end

0 commit comments

Comments
 (0)