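# Tests for KLMinRepGradDescent: KL minimization by stochastic gradient descent
# with the reparameterization-gradient ELBO estimator. The test setup is assumed
# to provide the AD backend `AD`, the `PROGRESS` flag, and the model fixtures
# `normal_meanfield` and `subsamplednormal`. A minimal usage sketch, mirroring
# the calls exercised below:
#
#     alg = KLMinRepGradDescent(AD; n_samples=10, operator=ClipScale())
#     q, info, state = optimize(alg, n_iterations, model, q0)
#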
@testset "KLMinRepGradDescent" begin
    begin
        modelstats = normal_meanfield(Random.default_rng(), Float64)
        (; model, n_dims, μ_true, L_true) = modelstats

        q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))

        @testset "basic n_samples=$(n_samples)" for n_samples in [1, 10]
            alg = KLMinRepGradDescent(AD; n_samples, operator=ClipScale())
            T = 1
            optimize(alg, T, model, q0; show_progress=PROGRESS)
        end

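        # The callback is expected to run once per iteration; any NamedTuple it
        # returns should be merged into that iteration's `info` entry.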
        @testset "callback" begin
            alg = KLMinRepGradDescent(AD; operator=ClipScale())
            T = 10
            callback(; iteration, kwargs...) = (iteration_check=iteration,)
            _, info, _ = optimize(alg, T, model, q0; callback, show_progress=PROGRESS)
            @test [i.iteration_check for i in info] == 1:T
        end

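        # `estimate_objective` should return a finite Monte Carlo estimate for
        # any `n_samples`. For this fixture the objective at the exact posterior
        # is zero, so a large sample size should drive the estimate near zero.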
        @testset "estimate_objective" begin
            alg = KLMinRepGradDescent(AD; operator=ClipScale())
            q_true = MeanFieldGaussian(Vector(μ_true), Diagonal(L_true))

            obj_est = estimate_objective(alg, q_true, model)
            @test isfinite(obj_est)

            obj_est = estimate_objective(alg, q_true, model; n_samples=1)
            @test isfinite(obj_est)

            obj_est = estimate_objective(alg, q_true, model; n_samples=3)
            @test isfinite(obj_est)

            obj_est = estimate_objective(alg, q_true, model; n_samples=10^5)
            @test obj_est ≈ 0 atol = 1e-2
        end

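        # Two optimizations started from identical `StableRNG` states should
        # produce bitwise-identical variational parameters.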
        @testset "determinism" begin
            alg = KLMinRepGradDescent(AD; operator=ClipScale())

            seed = 0x38bef07cf9cc549d
            rng = StableRNG(seed)
            T = 10

            q_out, _, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS)
            μ = q_out.location
            L = q_out.scale

            rng_repl = StableRNG(seed)
            q_out, _, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS)
            μ_repl = q_out.location
            L_repl = q_out.scale
            @test μ == μ_repl
            @test L == L_repl
        end

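        # Without a projection operator, the scale of a location-scale family is
        # not kept positive definite, so constructing the algorithm with
        # `IdentityOperator` should emit a warning mentioning it.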
        @testset "warn MvLocationScale with IdentityOperator" begin
            @test_warn "IdentityOperator" begin
                alg′ = KLMinRepGradDescent(AD; operator=IdentityOperator())
                optimize(alg′, 1, model, q0; show_progress=false)
            end
        end

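        # Sticking-the-landing (STL) control variate: when the variational
        # approximation equals the exact posterior, the STL gradient estimate is
        # exactly zero regardless of the number of Monte Carlo samples.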
        @testset "STL variance reduction" begin
            @testset for n_montecarlo in [1, 10]
                q_true = MeanFieldGaussian(Vector(μ_true), Diagonal(L_true))
                params, re = Optimisers.destructure(q_true)
                obj = RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy())
                out = DiffResults.DiffResult(zero(eltype(params)), similar(params))

                aux = (
                    rng=Random.default_rng(),
                    obj=obj,
                    problem=model,
                    restructure=re,
                    q_stop=q_true,
                    adtype=AD,
                )
                AdvancedVI._value_and_gradient!(
                    AdvancedVI.estimate_repgradelbo_ad_forward, out, AD, params, aux
                )
                grad = DiffResults.gradient(out)
                @test norm(grad) ≈ 0 atol = 1e-5
            end
        end
    end

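    # The element type of the initial approximation should propagate through
    # optimization to the output parameters and the recorded ELBO values.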
    @testset "type stability realtype=$(realtype)" for realtype in [Float32, Float64]
        modelstats = normal_meanfield(Random.default_rng(), realtype)
        (; model, n_dims, μ_true, L_true) = modelstats

        T = 1
        alg = KLMinRepGradDescent(AD; n_samples=10, operator=ClipScale())
        q0 = MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims)))

        q_out, info, _ = optimize(alg, T, model, q0; show_progress=PROGRESS)

        @test eltype(q_out.location) == realtype
        @test eltype(q_out.scale) == realtype
        @test typeof(first(info).elbo) == realtype
    end

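    # With a small constant step size, 1000 iterations should at least halve the
    # squared parameter error relative to initialization, for both the
    # closed-form and STL entropy estimators.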
    @testset "convergence $(entropy)" for entropy in
                                          [ClosedFormEntropy(), StickingTheLandingEntropy()]
        modelstats = normal_meanfield(Random.default_rng(), Float64)
        (; model, n_dims, μ_true, L_true) = modelstats

        T = 1000
        optimizer = Descent(1e-3)
        alg = KLMinRepGradDescent(AD; entropy, optimizer, operator=ClipScale())
        q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))

        q_out, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS)

        Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
        Δλ = sum(abs2, q_out.location - μ_true) + sum(abs2, q_out.scale - L_true)

        @test Δλ ≤ Δλ0 / 2
    end

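    # Doubly-stochastic setting: the data are subsampled into minibatches that
    # are reshuffled each epoch.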
    @testset "subsampling" begin
        n_data = 8

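        # The subsampled objective should agree with the full-data objective in
        # expectation, so with 10^5 samples the two estimates should match to
        # within Monte Carlo error.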
        @testset "estimate_objective batchsize=$(batchsize)" for batchsize in [1, 3, 4]
            modelstats = subsamplednormal(Random.default_rng(), n_data)
            (; model, n_dims, μ_true, L_true) = modelstats

            L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims))
            q0 = FullRankGaussian(zeros(Float64, n_dims), L0)
            operator = ClipScale()

            subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize)
            alg = KLMinRepGradDescent(AD; n_samples=10, operator)
            alg_sub = KLMinRepGradDescent(AD; n_samples=10, subsampling, operator)

            obj_full = estimate_objective(alg, q0, model; n_samples=10^5)
            obj_sub = estimate_objective(alg_sub, q0, model; n_samples=10^5)
            @test obj_full ≈ obj_sub rtol = 0.1
        end

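        # As in the non-subsampled case, identical RNG states should reproduce
        # the optimization exactly, now including the minibatch reshuffling.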
        @testset "determinism" begin
            seed = 0x38bef07cf9cc549d
            rng = StableRNG(seed)

            modelstats = subsamplednormal(Random.default_rng(), n_data)
            (; model, n_dims, μ_true, L_true) = modelstats

            L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims))
            q0 = FullRankGaussian(zeros(Float64, n_dims), L0)

            T = 10
            batchsize = 3
            subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize)
            alg_sub = KLMinRepGradDescent(
                AD; n_samples=10, subsampling, operator=ClipScale()
            )

            q, _, _ = optimize(rng, alg_sub, T, model, q0; show_progress=PROGRESS)
            μ = q.location
            L = q.scale

            rng_repl = StableRNG(seed)
            q, _, _ = optimize(rng_repl, alg_sub, T, model, q0; show_progress=PROGRESS)
            μ_repl = q.location
            L_repl = q.scale
            @test μ == μ_repl
            @test L == L_repl
        end

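        # Convergence check under the harshest subsampling setting (batchsize 1).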
        @testset "convergence" begin
            modelstats = subsamplednormal(Random.default_rng(), n_data)
            (; model, n_dims, μ_true, L_true) = modelstats

            L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims))
            q0 = FullRankGaussian(zeros(Float64, n_dims), L0)

            T = 1000
            batchsize = 1
            optimizer = Descent(1e-3)
            subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize)
            alg_sub = KLMinRepGradDescent(
                AD; n_samples=10, subsampling, optimizer, operator=ClipScale()
            )

            q, _, _ = optimize(alg_sub, T, model, q0; show_progress=PROGRESS)

            Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
            Δλ = sum(abs2, q.location - μ_true) + sum(abs2, q.scale - L_true)

            @test Δλ ≤ Δλ0 / 2
        end
    end
end