Skip to content

Commit b82331e

Browse files
committed
fix bug in calculating ELBO with subsampling
1 parent dadd318 commit b82331e

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

src/algorithms/paramspacesgd/subsampledobj.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ function estimate_objective(
4949
)
5050
(; objective, subsampling) = subobj
5151
sub_st = init(rng, subsampling)
52-
return mean(1:length(subsampling)) do _
52+
return mapreduce(+, 1:length(subsampling)) do _
5353
batch, sub_st, _ = step(rng, subsampling, sub_st)
5454
prob_sub = subsample(prob, batch)
5555
q_sub = subsample(q, batch)
56-
estimate_objective(rng, objective, q_sub, prob_sub; kwargs...)
56+
estimate_objective(rng, objective, q_sub, prob_sub; kwargs...) / length(subsampling)
5757
end
5858
end
5959

test/algorithms/paramspacesgd/subsampledobj.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@ end
6363

6464
@testset "estimate_objective batchsize=$(batchsize)" for batchsize in [1, 3, 4]
6565
sub_obj′ = SubsampledObjective(full_obj, batchsize, 1:n_data)
66-
full_objval = estimate_objective(full_obj, q0, prob; n_samples=10^6)
67-
sub_objval = estimate_objective(sub_obj′, q0, prob; n_samples=10^6)
68-
@info("", full_objval, sub_objval)
66+
full_objval = estimate_objective(full_obj, q0, prob; n_samples=10^8)
67+
sub_objval = estimate_objective(sub_obj′, q0, prob; n_samples=10^8)
6968
@test full_objval sub_objval rtol=0.1
7069
end
7170

0 commit comments

Comments
 (0)