Commit e332d8c

remove the type ParamSpaceSGD
1 parent 7cd9e56 commit e332d8c

4 files changed: +167 -79 lines changed

src/AdvancedVI.jl

Lines changed: 1 addition & 1 deletion
@@ -279,7 +279,6 @@ include("optimize.jl")
 
 ## Parameter Space SGD
 include("algorithms/paramspacesgd/abstractobjective.jl")
-include("algorithms/paramspacesgd/paramspacesgd.jl")
 
 export ParamSpaceSGD
 
@@ -319,6 +318,7 @@ export RepGradELBO,
     SubsampledObjective
 
 include("algorithms/paramspacesgd/constructors.jl")
+include("algorithms/paramspacesgd/paramspacesgd.jl")
 
 export KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent, ADVI, BBVI
 
src/algorithms/paramspacesgd/constructors.jl

Lines changed: 136 additions & 4 deletions
@@ -18,13 +18,53 @@ KL divergence minimization by running stochastic gradient descent with the repar
 - `operator::AbstractOperator`: Operator to be applied after each gradient descent step. (default: `IdentityOperator()`)
 - `subsampling::Union{<:Nothing,<:AbstractSubsampling}`: Data point subsampling strategy. If `nothing`, subsampling is not used. (default: `nothing`)
 
+# Output
+- `q_averaged`: The variational approximation formed by the averaged SGD iterates.
+
+# Callback
+The callback function `callback` has a signature of
+
+    callback(; rng, iteration, restructure, params, averaged_params, restructure, gradient)
+
+The arguments are as follows:
+- `rng`: Random number generator internally used by the algorithm.
+- `iteration`: The index of the current iteration.
+- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
+- `params`: Current variational parameters.
+- `averaged_params`: Variational parameters averaged according to the averaging strategy.
+- `gradient`: The estimated (possibly stochastic) gradient.
+
 # Requirements
 - The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`. This requires the variational approximation to be marked as a functor through `Functors.@functor`.
 - The variational approximation ``q_{\\lambda}`` implements `rand`.
 - The target distribution and the variational approximation have the same support.
 - The target `LogDensityProblems.logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
 - Additonal requirements on `q` may apply depending on the choice of `entropy`.
 """
+struct KLMinRepGradDescent{
+    Obj<:Union{<:RepGradELBO,<:SubsampledObjective},
+    AD<:ADTypes.AbstractADType,
+    Opt<:Optimisers.AbstractRule,
+    Avg<:AbstractAverager,
+    Op<:AbstractOperator,
+} <: AbstractVariationalAlgorithm
+    objective::Obj
+    adtype::AD
+    optimizer::Opt
+    averager::Avg
+    operator::Op
+end
+
+struct KLMinRepGradDescentState{P,Q,GradBuf,OptSt,ObjSt,AvgSt}
+    prob::P
+    q::Q
+    iteration::Int
+    grad_buf::GradBuf
+    opt_st::OptSt
+    obj_st::ObjSt
+    avg_st::AvgSt
+end
+
 function KLMinRepGradDescent(
     adtype::ADTypes.AbstractADType;
     entropy::Union{<:ClosedFormEntropy,<:StickingTheLandingEntropy,<:MonteCarloEntropy}=ClosedFormEntropy(),
@@ -39,7 +79,11 @@ function KLMinRepGradDescent(
     else
         SubsampledObjective(RepGradELBO(n_samples; entropy=entropy), subsampling)
     end
-    return ParamSpaceSGD(objective, adtype, optimizer, averager, operator)
+    return KLMinRepGradDescent{
+        typeof(objective),typeof(adtype),typeof(optimizer),typeof(averager),typeof(operator)
+    }(
+        objective, adtype, optimizer, averager, operator
+    )
 end
 
 const ADVI = KLMinRepGradDescent
@@ -63,12 +107,52 @@ Thus, only the entropy estimators with a "ZeroGradient" suffix are allowed.
 - `averager::AbstractAverager`: Parameter averaging strategy. (default: `PolynomialAveraging()`)
 - `subsampling::Union{<:Nothing,<:AbstractSubsampling}`: Data point subsampling strategy. If `nothing`, subsampling is not used. (default: `nothing`)
 
+# Output
+- `q_averaged`: The variational approximation formed by the averaged SGD iterates.
+
+# Callback
+The callback function `callback` has a signature of
+
+    callback(; rng, iteration, restructure, params, averaged_params, restructure, gradient)
+
+The arguments are as follows:
+- `rng`: Random number generator internally used by the algorithm.
+- `iteration`: The index of the current iteration.
+- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
+- `params`: Current variational parameters.
+- `averaged_params`: Variational parameters averaged according to the averaging strategy.
+- `gradient`: The estimated (possibly stochastic) gradient.
+
 # Requirements
 - The variational family is `MvLocationScale`.
 - The target distribution and the variational approximation have the same support.
 - The target `LogDensityProblems.logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
 - Additonal requirements on `q` may apply depending on the choice of `entropy_zerograd`.
 """
+struct KLMinRepGradProxDescent{
+    Obj<:Union{<:RepGradELBO,<:SubsampledObjective},
+    AD<:ADTypes.AbstractADType,
+    Opt<:Optimisers.AbstractRule,
+    Avg<:AbstractAverager,
+    Op<:ProximalLocationScaleEntropy,
+} <: AbstractVariationalAlgorithm
+    objective::Obj
+    adtype::AD
+    optimizer::Opt
+    averager::Avg
+    operator::Op
+end
+
+struct KLMinRepGradProxDescentState{P,Q,GradBuf,OptSt,ObjSt,AvgSt}
+    prob::P
+    q::Q
+    iteration::Int
+    grad_buf::GradBuf
+    opt_st::OptSt
+    obj_st::ObjSt
+    avg_st::AvgSt
+end
+
 function KLMinRepGradProxDescent(
     adtype::ADTypes.AbstractADType;
     entropy_zerograd::Union{
@@ -85,7 +169,11 @@ function KLMinRepGradProxDescent(
     else
         SubsampledObjective(RepGradELBO(n_samples; entropy=entropy_zerograd), subsampling)
     end
-    return ParamSpaceSGD(objective, adtype, optimizer, averager, operator)
+    return KLMinRepGradProxDescent{
+        typeof(objective),typeof(adtype),typeof(optimizer),typeof(averager),typeof(operator)
+    }(
+        objective, adtype, optimizer, averager, operator
+    )
 end
 
 """
@@ -106,15 +194,55 @@ KL divergence minimization by running stochastic gradient descent with the score
 - `operator::Union{<:IdentityOperator, <:ClipScale}`: Operator to be applied after each gradient descent step. (default: `IdentityOperator()`)
 - `subsampling::Union{<:Nothing,<:AbstractSubsampling}`: Data point subsampling strategy. If `nothing`, subsampling is not used. (default: `nothing`)
 
+# Output
+- `q_averaged`: The variational approximation formed by the averaged SGD iterates.
+
+# Callback
+The callback function `callback` has a signature of
+
+    callback(; rng, iteration, restructure, params, averaged_params, restructure, gradient)
+
+The arguments are as follows:
+- `rng`: Random number generator internally used by the algorithm.
+- `iteration`: The index of the current iteration.
+- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
+- `params`: Current variational parameters.
+- `averaged_params`: Variational parameters averaged according to the averaging strategy.
+- `gradient`: The estimated (possibly stochastic) gradient.
+
 # Requirements
 - The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`. This requires the variational approximation to be marked as a functor through `Functors.@functor`.
 - The variational approximation ``q_{\\lambda}`` implements `rand`.
 - The variational approximation ``q_{\\lambda}`` implements `logpdf(q, x)`, which should also be differentiable with respect to `x`.
 - The target distribution and the variational approximation have the same support.
 """
+struct KLMinScoreGradDescent{
+    Obj<:Union{<:ScoreGradELBO,<:SubsampledObjective},
+    AD<:ADTypes.AbstractADType,
+    Opt<:Optimisers.AbstractRule,
+    Avg<:AbstractAverager,
+    Op<:AbstractOperator,
+} <: AbstractVariationalAlgorithm
+    objective::Obj
+    adtype::AD
+    optimizer::Opt
+    averager::Avg
+    operator::Op
+end
+
+struct KLMinScoreGradDescentState{P,Q,GradBuf,OptSt,ObjSt,AvgSt}
+    prob::P
+    q::Q
+    iteration::Int
+    grad_buf::GradBuf
+    opt_st::OptSt
+    obj_st::ObjSt
+    avg_st::AvgSt
+end
+
 function KLMinScoreGradDescent(
     adtype::ADTypes.AbstractADType;
-    optimizer::Union{<:Descent,<:DoG,<:DoWG}=DoWG(),
+    optimizer::Optimisers.AbstractRule=DoWG(),
     n_samples::Int=1,
     averager::AbstractAverager=PolynomialAveraging(),
     operator::AbstractOperator=IdentityOperator(),
@@ -125,7 +253,11 @@ function KLMinScoreGradDescent(
     else
         SubsampledObjective(ScoreGradELBO(n_samples), subsampling)
     end
-    return ParamSpaceSGD(objective, adtype, optimizer, averager, operator)
+    return KLMinScoreGradDescent{
+        typeof(objective),typeof(adtype),typeof(optimizer),typeof(averager),typeof(operator)
+    }(
+        objective, adtype, optimizer, averager, operator
+    )
 end
 
 const BBVI = KLMinScoreGradDescent
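
For context, a minimal usage sketch of the constructor-based API documented above. The AD backend, iteration count, and the toy target are placeholders (any `LogDensityProblems`-compatible target would do), the `optimize` call follows the pattern in `test/general/optimize.jl` below, and forwarding of the `callback` keyword is assumed from the docstrings; treat this as an illustration, not library documentation.

```julia
using AdvancedVI, ADTypes, Optimisers, LinearAlgebra
using LogDensityProblems, ForwardDiff

# Toy standard-normal target so the sketch is self-contained (hypothetical, not from the diff).
struct NormalTarget end
LogDensityProblems.logdensity(::NormalTarget, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::NormalTarget) = 10
LogDensityProblems.capabilities(::Type{NormalTarget}) = LogDensityProblems.LogDensityOrder{0}()

prob = NormalTarget()
q0 = MeanFieldGaussian(zeros(Float64, 10), Diagonal(ones(Float64, 10)))

# The algorithm object bundles objective, AD backend, optimizer, averager, and operator.
alg = KLMinRepGradDescent(
    AutoForwardDiff();
    optimizer=Optimisers.Adam(1e-2),
    averager=PolynomialAveraging(),
    operator=ClipScale(),
)

# Callback following the keyword signature documented above; unused keywords are
# absorbed by `kwargs...`.
function report(; iteration, averaged_params, restructure, kwargs...)
    q_avg = restructure(averaged_params)    # rebuild the averaged approximation
    iteration % 100 == 0 && @info "iteration $iteration" q_avg
    return nothing
end

# Placeholder call mirroring test/general/optimize.jl; the return value is the
# averaged approximation `q_averaged` per the docstring.
optimize(alg, 1_000, prob, q0; show_progress=false, callback=report)
```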

src/algorithms/paramspacesgd/paramspacesgd.jl

Lines changed: 29 additions & 68 deletions
@@ -1,68 +1,9 @@
 
-"""
-    ParamSpaceSGD(
-        objective::AbstractVariationalObjective,
-        adtype::ADTypes.AbstractADType,
-        optimizer::Optimisers.AbstractRule,
-        averager::AbstractAverager,
-        operator::AbstractOperator,
-    )
-
-This algorithm applies stochastic gradient descent (SGD) to the variational `objective` over the (Euclidean) space of variational parameters.
-
-The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`.
-This requires the variational approximation to be marked as a functor through `Functors.@functor`.
-
-!!! note
-    Different objective may impose different requirements on `adtype`, variational family, `optimizer`, and `operator`. It is therefore important to check the documentation corresponding to each specific objective. Essentially, each objective should be thought as forming its own unique algorithm.
-
-# Arguments
-- `objective`: Variational Objective.
-- `adtype`: Automatic differentiation backend.
-- `optimizer`: Optimizer used for inference.
-- `averager` : Parameter averaging strategy.
-- `operator` : Operator applied to the parameters after each optimization step.
-
-# Output
-- `q_averaged`: The variational approximation formed from the averaged SGD iterates.
-
-# Callback
-The callback function `callback` has a signature of
-
-    callback(; rng, iteration, restructure, params, averaged_params, restructure, gradient)
-
-The arguments are as follows:
-- `rng`: Random number generator internally used by the algorithm.
-- `iteration`: The index of the current iteration.
-- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
-- `params`: Current variational parameters.
-- `averaged_params`: Variational parameters averaged according to the averaging strategy.
-- `gradient`: The estimated (possibly stochastic) gradient.
-
-"""
-struct ParamSpaceSGD{
-    Obj<:AbstractVariationalObjective,
-    AD<:ADTypes.AbstractADType,
-    Opt<:Optimisers.AbstractRule,
-    Avg<:AbstractAverager,
-    Op<:AbstractOperator,
-} <: AbstractVariationalAlgorithm
-    objective::Obj
-    adtype::AD
-    optimizer::Opt
-    averager::Avg
-    operator::Op
-end
-
-struct ParamSpaceSGDState{P,Q,GradBuf,OptSt,ObjSt,AvgSt}
-    prob::P
-    q::Q
-    iteration::Int
-    grad_buf::GradBuf
-    opt_st::OptSt
-    obj_st::ObjSt
-    avg_st::AvgSt
-end
+const ParamSpaceSGD = Union{
+    <:KLMinRepGradDescent,
+    <:KLMinRepGradProxDescent,
+    <:KLMinScoreGradDescent,
+}
 
 function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, q_init, prob)
     (; adtype, optimizer, averager, objective, operator) = alg
@@ -76,7 +17,15 @@ function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, q_init, prob)
     obj_st = init(rng, objective, adtype, q_init, prob, params, re)
     avg_st = init(averager, params)
     grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
-    return ParamSpaceSGDState(prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st)
+    if alg isa KLMinRepGradDescent
+        return KLMinRepGradDescentState(prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st)
+    elseif alg isa KLMinRepGradProxDescent
+        return KLMinRepGradProxDescentState(prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st)
+    elseif alg isa KLMinScoreGradDescent
+        return KLMinScoreGradDescentState(prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st)
+    else
+        nothing
+    end
 end
 
 function output(alg::ParamSpaceSGD, state)
@@ -104,9 +53,21 @@ function step(
     params = apply(operator, typeof(q), opt_st, params, re)
     avg_st = apply(averager, avg_st, params)
 
-    state = ParamSpaceSGDState(
-        prob, re(params), iteration, grad_buf, opt_st, obj_st, avg_st
-    )
+    state = if alg isa KLMinRepGradDescent
+        KLMinRepGradDescentState(
+            prob, re(params), iteration, grad_buf, opt_st, obj_st, avg_st
+        )
+    elseif alg isa KLMinRepGradProxDescent
+        KLMinRepGradProxDescentState(
+            prob, re(params), iteration, grad_buf, opt_st, obj_st, avg_st
+        )
+    elseif alg isa KLMinScoreGradDescent
+        KLMinScoreGradDescentState(
+            prob, re(params), iteration, grad_buf, opt_st, obj_st, avg_st
+        )
+    else
+        nothing
+    end
 
     if !isnothing(callback)
         averaged_params = value(averager, avg_st)
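
The `ParamSpaceSGD` struct is gone, but the shared `init`/`step`/`output` methods still need a single signature that covers all three algorithms; the `Union` alias provides that while each algorithm keeps its own concrete type and state. A self-contained sketch of the same pattern, with illustrative names only (not AdvancedVI internals):

```julia
# Each algorithm is its own concrete type; a Union alias stands in for the old
# shared supertype so one generic method body serves all of them.
abstract type AbstractAlgorithm end

struct AlgA <: AbstractAlgorithm end   # plays the role of KLMinRepGradDescent
struct AlgB <: AbstractAlgorithm end   # plays the role of KLMinScoreGradDescent

struct StateA end                      # plays the role of KLMinRepGradDescentState
struct StateB end                      # plays the role of KLMinScoreGradDescentState

const SGDLike = Union{AlgA,AlgB}       # counterpart of the new ParamSpaceSGD alias

# One shared method; the concrete state is picked by branching on the algorithm
# type, mirroring `init` and `step` above.
init_state(alg::SGDLike) = alg isa AlgA ? StateA() : StateB()

init_state(AlgA())  # StateA()
init_state(AlgB())  # StateB()
```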

test/general/optimize.jl

Lines changed: 1 addition & 6 deletions
@@ -9,12 +9,7 @@
     (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats
 
     q0 = MeanFieldGaussian(zeros(Float64, n_dims), Diagonal(ones(Float64, n_dims)))
-    obj = RepGradELBO(10)
-
-    optimizer = Optimisers.Adam(1e-2)
-    averager = PolynomialAveraging()
-
-    alg = ParamSpaceSGD(obj, AD, optimizer, averager, IdentityOperator())
+    alg = KLMinRepGradDescent(AD; optimizer=Optimisers.Adam(1e-2), operator=ClipScale())
 
     @testset "default_rng" begin
         optimize(alg, T, model, q0; show_progress=false)
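
Two side effects of this commit, visible in constructors.jl above, are that the exported `ADVI`/`BBVI` aliases now point at the new concrete types and that `KLMinScoreGradDescent` accepts any `Optimisers.AbstractRule` instead of only `Descent`/`DoG`/`DoWG`. A small hedged sketch; the AD backend, step size, and sample count are placeholders:

```julia
using AdvancedVI, ADTypes, Optimisers

ADVI === KLMinRepGradDescent    # true: alias kept, per constructors.jl
BBVI === KLMinScoreGradDescent  # true: alias kept, per constructors.jl

# Previously restricted to Union{<:Descent,<:DoG,<:DoWG}; a generic rule such as
# Adam is now accepted by the score-gradient constructor.
alg = BBVI(AutoForwardDiff(); optimizer=Optimisers.Adam(1e-3), n_samples=8)
```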
