
Commit e8f00d8

rename AbstractAlgorithm to AbstractVariationalAlgorithm
1 parent: f47189d

6 files changed: +29 / -24 lines

HISTORY.md

Lines changed: 2 additions & 2 deletions
@@ -10,11 +10,11 @@ Therefore, in case a variational family of `<:MvLocationScale` is used in combin
 
 ## Interface Changes
 
-An additional layer of indirection, `AbstractAlgorithms` has been added.
+An additional layer of indirection, `AbstractVariationalAlgorithms` has been added.
 Previously, all variational inference algorithms were assumed to run SGD in parameter space.
 This design, however, is proving to be too rigid.
 Instead, each algorithm is now assumed to implement three simple interfaces: `init`, `step`, and `output`.
-Algorithms that run SGD in parameter space now need to implement the `AbstractVarationalObjective` interface of `ParamSpaceSGD <: AbstractAlgorithms`, which is a general implementation of the new interface.
+Algorithms that run SGD in parameter space now need to implement the `AbstractVarationalObjective` interface of `ParamSpaceSGD <: AbstractVariationalAlgorithms`, which is a general implementation of the new interface.
 Therefore, the old behavior of `AdvancedVI` is fully inherited by `ParamSpaceSGD`.
 
 ## Internal Changes
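
The interface change summarized above amounts to three methods per algorithm. Below is a minimal sketch of what implementing them could look like for a hypothetical custom algorithm; the type `MyAlgorithm`, its field, and the state layout are illustrative only, the method signatures follow the docstrings further down in this commit, and `step` is assumed here to return the updated state together with an `info` NamedTuple:

```julia
using Random
using AdvancedVI

# Hypothetical algorithm; the name and field are placeholders, not part of AdvancedVI.
struct MyAlgorithm <: AdvancedVI.AbstractVariationalAlgorithm
    n_samples::Int
end

# init(rng, alg, q_init, prob): build the initial state from the starting approximation.
function AdvancedVI.init(rng::Random.AbstractRNG, alg::MyAlgorithm, q_init, prob)
    return (q=q_init, iteration=0)
end

# step(rng, alg, state, callback, objargs...; kwargs...): advance the state by one iteration.
function AdvancedVI.step(
    rng::Random.AbstractRNG, alg::MyAlgorithm, state, callback, objargs...; kwargs...
)
    new_state = (q=state.q, iteration=state.iteration + 1)  # actual update rule omitted
    info = (iteration=new_state.iteration,)
    return new_state, info  # return convention assumed, not shown in this diff
end

# output(alg, state): extract the final variational approximation from the last state.
AdvancedVI.output(alg::MyAlgorithm, state) = state.q
```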

docs/src/general.md

Lines changed: 4 additions & 4 deletions
@@ -7,7 +7,7 @@ Then the algorithm can be executed by invoking `optimize`. (See [this section](@
 
 ## [Optimize](@id optimize)
 
-Given a subtype of `AbstractAlgorithm` associated with each algorithm, it suffices to call the function `optimize`:
+Given a subtype of `AbstractVariationalAlgorithm` associated with each algorithm, it suffices to call the function `optimize`:
 
 ```@docs
 optimize
@@ -18,16 +18,16 @@ Therefore, please refer to the documentation of each different algorithm for a d
 
 ## [Algorithm Interface](@id algorithm)
 
-A variational inference algorithm supported by `AdvancedVI` should define its own subtype of `AbstractAlgorithm`:
+A variational inference algorithm supported by `AdvancedVI` should define its own subtype of `AbstractVariationalAlgorithm`:
 
 ```@docs
-AdvancedVI.AbstractAlgorithm
+AdvancedVI.AbstractVariationalAlgorithm
 ```
 
 The functionality of each algorithm is then implemented through the following methods:
 
 ```@docs
-AdvancedVI.init(::Random.AbstractRNG, ::AdvancedVI.AbstractAlgorithm, ::Any, ::Any)
+AdvancedVI.init(::Random.AbstractRNG, ::AdvancedVI.AbstractVariationalAlgorithm, ::Any, ::Any)
 AdvancedVI.step
 AdvancedVI.output
 ```
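
As a usage illustration of the `optimize` entry point referenced in the docs above, here is a hedged sketch building on the hypothetical `MyAlgorithm` from earlier. `prob` and `q0` are placeholders (in practice the target would typically implement the LogDensityProblems interface), and the structure of the return value is not shown in this diff, so it is left unpacked:

```julia
using Random
using AdvancedVI

prob = nothing   # placeholder target problem (ignored by the MyAlgorithm sketch)
q0 = randn(2)    # placeholder initial variational approximation

# Documented call shape: optimize([rng,] algorithm, max_iter, prob, q_init, objargs...; kwargs...)
result = optimize(Random.default_rng(), MyAlgorithm(10), 1_000, prob, q0)
```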

src/AdvancedVI.jl

Lines changed: 13 additions & 8 deletions
@@ -192,11 +192,11 @@ export IdentityOperator, ClipScale, ProximalLocationScaleEntropy
 # Algorithms
 
 """
-    AbstractAlgorithm
+    AbstractVariationalAlgorithm
 
 Abstract type for a variational inference algorithm.
 """
-abstract type AbstractAlgorithm end
+abstract type AbstractVariationalAlgorithm end
 
 """
     init(rng, alg, q_init, prob)
@@ -205,14 +205,14 @@ Initialize `alg` given the initial variational approximation `q_init` and the ta
 
 # Arguments
 - `rng::Random.AbstractRNG`: Random number generator.
-- `alg::AbstractAlgorithm`: Variational inference algorithm.
+- `alg::AbstractVariationalAlgorithm`: Variational inference algorithm.
 - `q_init`: Initial variational approximation.
 - `prob`: Target problem.
 
 # Returns
 - `state`: Initial state of the algorithm.
 """
-init(::Random.AbstractRNG, ::AbstractAlgorithm, ::Any, ::Any) = nothing
+init(::Random.AbstractRNG, ::AbstractVariationalAlgorithm, ::Any, ::Any) = nothing
 
 """
     step(rng, alg, state, callback, objargs...; kwargs...)
@@ -221,7 +221,7 @@ Perform a single step of `alg` given the previous `state`.
 
 # Arguments
 - `rng::Random.AbstractRNG`: Random number generator.
-- `alg::AbstractAlgorithm`: Variational inference algorithm.
+- `alg::AbstractVariationalAlgorithm`: Variational inference algorithm.
 - `state`: Previous state of the algorithm.
 - `callback`: Callback function to be called during the step.
@@ -231,7 +231,12 @@ Perform a single step of `alg` given the previous `state`.
 - `info`::NamedTuple: Information generated during the step.
 """
 function step(
-    ::Random.AbstractRNG, ::AbstractAlgorithm, ::Any, callback, objargs...; kwargs...
+    ::Random.AbstractRNG,
+    ::AbstractVariationalAlgorithm,
+    ::Any,
+    callback,
+    objargs...;
+    kwargs...,
 )
     nothing
 end
@@ -242,13 +247,13 @@ end
 Output a variational approximation from the last `state` of `alg`.
 
 # Arguments
-- `alg::AbstractAlgorithm`: Variational inference algorithm used to compute the state.
+- `alg::AbstractVariationalAlgorithm`: Variational inference algorithm used to compute the state.
 - `state`: The last state generated by the algorithm.
 
 # Returns
 - `out`: The output of the algorithm.
 """
-output(::AbstractAlgorithm, ::Any) = nothing
+output(::AbstractVariationalAlgorithm, ::Any) = nothing
 
 # Subsampling
 """

src/algorithms/paramspacesgd/paramspacesgd.jl

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ struct ParamSpaceSGD{
     Opt<:Optimisers.AbstractRule,
     Avg<:AbstractAverager,
     Op<:AbstractOperator,
-} <: AbstractAlgorithm
+} <: AbstractVariationalAlgorithm
     objective::Obj
     adtype::AD
     optimizer::Opt
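
For orientation, constructing a `ParamSpaceSGD` presumably looks something like the sketch below. The positional field order (objective, adtype, optimizer, averager, operator) is inferred from the struct definition above and its type parameters, and the concrete objective, AD backend, optimizer, averager, and operator picked here are assumptions for illustration rather than part of this commit:

```julia
using AdvancedVI
using ADTypes: AutoForwardDiff
using Optimisers: Adam

alg = ParamSpaceSGD(
    RepGradELBO(10),        # objective (see repgradelbo.jl below)
    AutoForwardDiff(),      # adtype: AD backend from ADTypes
    Adam(1e-3),             # optimizer::Optimisers.AbstractRule
    PolynomialAveraging(),  # averager::AbstractAverager (assumed choice)
    ClipScale(),            # operator::AbstractOperator (exported per the diff above)
)
```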

src/algorithms/paramspacesgd/repgradelbo.jl

Lines changed: 1 addition & 6 deletions
@@ -61,12 +61,7 @@ function init(
         MixedADLogDensityProblem(prob)
     end
     aux = (
-        rng=rng,
-        adtype=adtype,
-        obj=obj,
-        problem=ad_prob,
-        restructure=restructure,
-        q_stop=q,
+        rng=rng, adtype=adtype, obj=obj, problem=ad_prob, restructure=restructure, q_stop=q
     )
     obj_ad_prep = AdvancedVI._prepare_gradient(
         estimate_repgradelbo_ad_forward, adtype, params, aux

src/optimize.jl

Lines changed: 8 additions & 3 deletions
@@ -2,7 +2,7 @@
 """
     optimize(
         [rng::Random.AbstractRNG = Random.default_rng(),]
-        algorithm::AbstractAlgorithm,
+        algorithm::AbstractVariationalAlgorithm,
         max_iter::Int,
         prob,
         q_init,
@@ -41,7 +41,7 @@ The content of the `NamedTuple` will be concatenated into the corresponding entr
 """
 function optimize(
     rng::Random.AbstractRNG,
-    algorithm::AbstractAlgorithm,
+    algorithm::AbstractVariationalAlgorithm,
     max_iter::Int,
     prob,
     q_init,
@@ -81,7 +81,12 @@ function optimize(
 end
 
 function optimize(
-    algorithm::AbstractAlgorithm, max_iter::Int, prob, q_init, objargs...; kwargs...
+    algorithm::AbstractVariationalAlgorithm,
+    max_iter::Int,
+    prob,
+    q_init,
+    objargs...;
+    kwargs...,
 )
     return optimize(
         Random.default_rng(), algorithm, max_iter, prob, q_init, objargs...; kwargs...
