Merged
4 changes: 2 additions & 2 deletions HISTORY.md
@@ -10,11 +10,11 @@ Therefore, in case a variational family of `<:MvLocationScale` is used in combin

## Interface Changes

An additional layer of indirection, `AbstractAlgorithms` has been added.
An additional layer of indirection, `AbstractVariationalAlgorithms` has been added.
Previously, all variational inference algorithms were assumed to run SGD in parameter space.
This design, however, is proving to be too rigid.
Instead, each algorithm is now assumed to implement three simple interfaces: `init`, `step`, and `output`.
Algorithms that run SGD in parameter space now need to implement the `AbstractVariationalObjective` interface of `ParamSpaceSGD <: AbstractAlgorithms`, which is a general implementation of the new interface.
Algorithms that run SGD in parameter space now need to implement the `AbstractVariationalObjective` interface of `ParamSpaceSGD <: AbstractVariationalAlgorithms`, which is a general implementation of the new interface.
Therefore, the old behavior of `AdvancedVI` is fully inherited by `ParamSpaceSGD`.

## Internal Changes
10 changes: 5 additions & 5 deletions docs/src/general.md
@@ -2,12 +2,12 @@
# [General Usage](@id general)

AdvancedVI provides multiple variational inference (VI) algorithms.
Each algorithm defines its subtype of [`AdvancedVI.AbstractAlgorithm`](@ref) with some corresponding methods (see [this section](@ref algorithm)).
Each algorithm defines its subtype of [`AdvancedVI.AbstractVariationalAlgorithm`](@ref) with some corresponding methods (see [this section](@ref algorithm)).
The algorithm can then be executed by invoking `optimize` (see [this section](@ref optimize)).

## [Optimize](@id optimize)

Given a subtype of `AbstractAlgorithm` associated with each algorithm, it suffices to call the function `optimize`:
Given a subtype of `AbstractVariationalAlgorithm` associated with each algorithm, it suffices to call the function `optimize`:

```@docs
optimize
Expand All @@ -18,16 +18,16 @@ Therefore, please refer to the documentation of each different algorithm for a d

## [Algorithm Interface](@id algorithm)

A variational inference algorithm supported by `AdvancedVI` should define its own subtype of `AbstractAlgorithm`:
A variational inference algorithm supported by `AdvancedVI` should define its own subtype of `AbstractVariationalAlgorithm`:

```@docs
AdvancedVI.AbstractAlgorithm
AdvancedVI.AbstractVariationalAlgorithm
```

The functionality of each algorithm is then implemented through the following methods:

```@docs
AdvancedVI.init(::Random.AbstractRNG, ::AdvancedVI.AbstractAlgorithm, ::Any, ::Any)
AdvancedVI.init(::Random.AbstractRNG, ::AdvancedVI.AbstractVariationalAlgorithm, ::Any, ::Any)
AdvancedVI.step
AdvancedVI.output
```
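To make the interface concrete, here is a minimal sketch of a do-nothing algorithm. The `NullAlgorithm` name, its state layout, and the locally redeclared abstract type are invented for illustration; only the `init`/`step`/`output` signatures follow the documented interface.

```julia
using Random

# Redeclared here so the sketch is self-contained; in practice, subtype
# AdvancedVI.AbstractVariationalAlgorithm instead.
abstract type AbstractVariationalAlgorithm end

# A do-nothing algorithm: the variational approximation never changes.
struct NullAlgorithm <: AbstractVariationalAlgorithm end

# `init` builds the initial state from the initial approximation and problem.
init(rng::Random.AbstractRNG, ::NullAlgorithm, q_init, prob) = (q=q_init,)

# `step` returns the updated state together with an `info` NamedTuple.
function step(rng::Random.AbstractRNG, ::NullAlgorithm, state, callback, objargs...; kwargs...)
    return state, (changed=false,)
end

# `output` extracts the final variational approximation from the last state.
output(::NullAlgorithm, state) = state.q
```

Once these three methods are defined for a subtype, the algorithm can be driven by `optimize` like any built-in one.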
21 changes: 13 additions & 8 deletions src/AdvancedVI.jl
@@ -192,11 +192,11 @@ export IdentityOperator, ClipScale, ProximalLocationScaleEntropy
# Algorithms

"""
AbstractAlgorithm
AbstractVariationalAlgorithm

Abstract type for a variational inference algorithm.
"""
abstract type AbstractAlgorithm end
abstract type AbstractVariationalAlgorithm end

"""
init(rng, alg, q_init, prob)
@@ -205,14 +205,14 @@ Initialize `alg` given the initial variational approximation `q_init` and the ta

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `alg::AbstractAlgorithm`: Variational inference algorithm.
- `alg::AbstractVariationalAlgorithm`: Variational inference algorithm.
- `q_init`: Initial variational approximation.
- `prob`: Target problem.

# Returns
- `state`: Initial state of the algorithm.
"""
init(::Random.AbstractRNG, ::AbstractAlgorithm, ::Any, ::Any) = nothing
init(::Random.AbstractRNG, ::AbstractVariationalAlgorithm, ::Any, ::Any) = nothing

"""
step(rng, alg, state, callback, objargs...; kwargs...)
@@ -221,7 +221,7 @@ Perform a single step of `alg` given the previous `state`.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `alg::AbstractAlgorithm`: Variational inference algorithm.
- `alg::AbstractVariationalAlgorithm`: Variational inference algorithm.
- `state`: Previous state of the algorithm.
- `callback`: Callback function to be called during the step.

@@ -231,7 +231,12 @@ Perform a single step of `alg` given the previous `state`.
- `info::NamedTuple`: Information generated during the step.
"""
function step(
::Random.AbstractRNG, ::AbstractAlgorithm, ::Any, callback, objargs...; kwargs...
::Random.AbstractRNG,
::AbstractVariationalAlgorithm,
::Any,
callback,
objargs...;
kwargs...,
)
nothing
end
@@ -242,13 +247,13 @@ end
Output a variational approximation from the last `state` of `alg`.

# Arguments
- `alg::AbstractAlgorithm`: Variational inference algorithm used to compute the state.
- `alg::AbstractVariationalAlgorithm`: Variational inference algorithm used to compute the state.
- `state`: The last state generated by the algorithm.

# Returns
- `out`: The output of the algorithm.
"""
output(::AbstractAlgorithm, ::Any) = nothing
output(::AbstractVariationalAlgorithm, ::Any) = nothing
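As a hedged illustration of how these three methods compose (the `CountingVI` type and the `run_demo` driver are invented for this sketch and are not part of the package), a driver loop in the spirit of `optimize` threads the state through repeated `step` calls and collects each `info`:

```julia
using Random

# Invented toy algorithm: its state just counts iterations.
struct CountingVI end

my_init(rng, ::CountingVI, q_init, prob) = (q=q_init, iter=0)

function my_step(rng, ::CountingVI, state, callback; kwargs...)
    new_state = (q=state.q, iter=state.iter + 1)
    return new_state, (iteration=new_state.iter,)
end

my_output(::CountingVI, state) = state.q

# Simplified driver: init once, step repeatedly, output at the end. The real
# driver additionally handles callbacks, progress, and info accumulation.
function run_demo(n::Int)
    rng = Random.default_rng()
    state = my_init(rng, CountingVI(), :q0, nothing)
    infos = NamedTuple[]
    for _ in 1:n
        state, info = my_step(rng, CountingVI(), state, nothing)
        push!(infos, info)
    end
    return my_output(CountingVI(), state), infos
end
```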

# Subsampling
"""
2 changes: 1 addition & 1 deletion src/algorithms/paramspacesgd/paramspacesgd.jl
@@ -46,7 +46,7 @@ struct ParamSpaceSGD{
Opt<:Optimisers.AbstractRule,
Avg<:AbstractAverager,
Op<:AbstractOperator,
} <: AbstractAlgorithm
} <: AbstractVariationalAlgorithm
objective::Obj
adtype::AD
optimizer::Opt
7 changes: 1 addition & 6 deletions src/algorithms/paramspacesgd/repgradelbo.jl
@@ -61,12 +61,7 @@ function init(
MixedADLogDensityProblem(prob)
end
aux = (
rng=rng,
adtype=adtype,
obj=obj,
problem=ad_prob,
restructure=restructure,
q_stop=q,
rng=rng, adtype=adtype, obj=obj, problem=ad_prob, restructure=restructure, q_stop=q
)
obj_ad_prep = AdvancedVI._prepare_gradient(
estimate_repgradelbo_ad_forward, adtype, params, aux
11 changes: 8 additions & 3 deletions src/optimize.jl
@@ -2,7 +2,7 @@
"""
optimize(
[rng::Random.AbstractRNG = Random.default_rng(),]
algorithm::AbstractAlgorithm,
algorithm::AbstractVariationalAlgorithm,
max_iter::Int,
prob,
q_init,
@@ -41,7 +41,7 @@ The content of the `NamedTuple` will be concatenated into the corresponding entr
"""
function optimize(
rng::Random.AbstractRNG,
algorithm::AbstractAlgorithm,
algorithm::AbstractVariationalAlgorithm,
max_iter::Int,
prob,
q_init,
@@ -81,7 +81,12 @@ function optimize(
end

function optimize(
algorithm::AbstractAlgorithm, max_iter::Int, prob, q_init, objargs...; kwargs...
algorithm::AbstractVariationalAlgorithm,
max_iter::Int,
prob,
q_init,
objargs...;
kwargs...,
)
return optimize(
Random.default_rng(), algorithm, max_iter, prob, q_init, objargs...; kwargs...
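The second `optimize` method above is a convenience overload that only supplies a default RNG and forwards everything else. A hedged sketch of that forwarding pattern, with invented `my_optimize`/`MyAlg` names rather than the package API:

```julia
using Random

struct MyAlg end

# Method with an explicit RNG: this is where the actual work would live.
my_optimize(rng::Random.AbstractRNG, ::MyAlg, max_iter::Int, prob, q_init) = (q_init, max_iter)

# Convenience method: fills in `Random.default_rng()` and forwards all
# remaining arguments unchanged.
my_optimize(alg::MyAlg, max_iter::Int, prob, q_init) =
    my_optimize(Random.default_rng(), alg, max_iter, prob, q_init)
```

Keeping the RNG-taking method as the single point of implementation means reproducible runs only require passing a seeded RNG to the first method.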