diff --git a/HISTORY.md b/HISTORY.md
index c1a9bb01..ea309400 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -10,11 +10,11 @@ Therefore, in case a variational family of `<:MvLocationScale` is used in combin
 
 ## Interface Changes
 
-An additional layer of indirection, `AbstractAlgorithms` has been added.
+An additional layer of indirection, `AbstractVariationalAlgorithm`, has been added.
 Previously, all variational inference algorithms were assumed to run SGD in parameter space.
-This desing however, is proving to be too rigid.
+This design, however, proved to be too rigid.
 Instead, each algorithm is now assumed to implement three simple interfaces: `init`, `step`, and `output`.
-Algorithms that run SGD in parameter space now need to implement the `AbstractVarationalObjective` interface of `ParamSpaceSGD <: AbstractAlgorithms`, which is a general implementation of the new interface.
-Therefore, the old behavior of `AdvancedVI` is fully inhereted by `ParamSpaceSGD`.
+Algorithms that run SGD in parameter space now need to implement the `AbstractVariationalObjective` interface of `ParamSpaceSGD <: AbstractVariationalAlgorithm`, which is a general implementation of the new interface.
+Therefore, the old behavior of `AdvancedVI` is fully inherited by `ParamSpaceSGD`.
 
 ## Internal Changes
diff --git a/docs/src/general.md b/docs/src/general.md
index 417d6b22..b1c5f7c4 100644
--- a/docs/src/general.md
+++ b/docs/src/general.md
@@ -2,12 +2,12 @@
 # [General Usage](@id general)
 
 AdvancedVI provides multiple variational inference (VI) algorithms.
-Each algorithm defines its subtype of [`AdvancedVI.AbstractAlgorithm`](@ref) with some corresponding methods (see [this section](@ref algorithm)).
+Each algorithm defines its own subtype of [`AdvancedVI.AbstractVariationalAlgorithm`](@ref) with some corresponding methods (see [this section](@ref algorithm)).
 Then the algorithm can be executed by invoking `optimize`. (See [this section](@ref optimize)).
 
 ## [Optimize](@id optimize)
 
-Given a subtype of `AbstractAlgorithm` associated with each algorithm, it suffices to call the function `optimize`:
+Given the subtype of `AbstractVariationalAlgorithm` associated with an algorithm, it suffices to call the function `optimize`:
 
 ```@docs
 optimize
@@ -18,16 +18,16 @@ Therefore, please refer to the documentation of each different algorithm for a d
 
 ## [Algorithm Interface](@id algorithm)
 
-A variational inference algorithm supported by `AdvancedVI` should define its own subtype of `AbstractAlgorithm`:
+A variational inference algorithm supported by `AdvancedVI` should define its own subtype of `AbstractVariationalAlgorithm`:
 
 ```@docs
-AdvancedVI.AbstractAlgorithm
+AdvancedVI.AbstractVariationalAlgorithm
 ```
 
 The functionality of each algorithm is then implemented through the following methods:
 
 ```@docs
-AdvancedVI.init(::Random.AbstractRNG, ::AdvancedVI.AbstractAlgorithm, ::Any, ::Any)
+AdvancedVI.init(::Random.AbstractRNG, ::AdvancedVI.AbstractVariationalAlgorithm, ::Any, ::Any)
 AdvancedVI.step
 AdvancedVI.output
 ```
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index f3b10653..6f51a9f0 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -192,11 +192,11 @@ export IdentityOperator, ClipScale, ProximalLocationScaleEntropy
 # Algorithms
 
 """
-    AbstractAlgorithm
+    AbstractVariationalAlgorithm
 
 Abstract type for a variational inference algorithm.
 """
-abstract type AbstractAlgorithm end
+abstract type AbstractVariationalAlgorithm end
 
 """
     init(rng, alg, q_init, prob)
@@ -205,14 +205,14 @@ Initialize `alg` given the initial variational approximation `q_init` and the ta
 
 # Arguments
 - `rng::Random.AbstractRNG`: Random number generator.
-- `alg::AbstractAlgorithm`: Variational inference algorithm.
+- `alg::AbstractVariationalAlgorithm`: Variational inference algorithm.
 - `q_init`: Initial variational approximation.
 - `prob`: Target problem.
 
 # Returns
 - `state`: Initial state of the algorithm.
 """
-init(::Random.AbstractRNG, ::AbstractAlgorithm, ::Any, ::Any) = nothing
+init(::Random.AbstractRNG, ::AbstractVariationalAlgorithm, ::Any, ::Any) = nothing
 
 """
     step(rng, alg, state, callback, objargs...; kwargs...)
@@ -221,7 +221,7 @@ Perform a single step of `alg` given the previous `state`.
 
 # Arguments
 - `rng::Random.AbstractRNG`: Random number generator.
-- `alg::AbstractAlgorithm`: Variational inference algorithm.
+- `alg::AbstractVariationalAlgorithm`: Variational inference algorithm.
 - `state`: Previous state of the algorithm.
 - `callback`: Callback function to be called during the step.
 
@@ -231,7 +231,12 @@ Perform a single step of `alg` given the previous `state`.
 - `info::NamedTuple`: Information generated during the step.
 """
 function step(
-    ::Random.AbstractRNG, ::AbstractAlgorithm, ::Any, callback, objargs...; kwargs...
+    ::Random.AbstractRNG,
+    ::AbstractVariationalAlgorithm,
+    ::Any,
+    callback,
+    objargs...;
+    kwargs...,
 )
     nothing
 end
@@ -242,13 +247,13 @@ end
 Output a variational approximation from the last `state` of `alg`.
 
 # Arguments
-- `alg::AbstractAlgorithm`: Variational inference algorithm used to compute the state.
+- `alg::AbstractVariationalAlgorithm`: Variational inference algorithm used to compute the state.
 - `state`: The last state generated by the algorithm.
 
 # Returns
 - `out`: The output of the algorithm.
 """
-output(::AbstractAlgorithm, ::Any) = nothing
+output(::AbstractVariationalAlgorithm, ::Any) = nothing
 
 # Subsampling
 """
diff --git a/src/algorithms/paramspacesgd/paramspacesgd.jl b/src/algorithms/paramspacesgd/paramspacesgd.jl
index bc2dd319..92bbb0e5 100644
--- a/src/algorithms/paramspacesgd/paramspacesgd.jl
+++ b/src/algorithms/paramspacesgd/paramspacesgd.jl
@@ -46,7 +46,7 @@ struct ParamSpaceSGD{
     Opt<:Optimisers.AbstractRule,
     Avg<:AbstractAverager,
     Op<:AbstractOperator,
-} <: AbstractAlgorithm
+} <: AbstractVariationalAlgorithm
     objective::Obj
     adtype::AD
     optimizer::Opt
diff --git a/src/algorithms/paramspacesgd/repgradelbo.jl b/src/algorithms/paramspacesgd/repgradelbo.jl
index a1cdcbd9..cd0fbf1b 100644
--- a/src/algorithms/paramspacesgd/repgradelbo.jl
+++ b/src/algorithms/paramspacesgd/repgradelbo.jl
@@ -61,12 +61,7 @@ function init(
         MixedADLogDensityProblem(prob)
     end
     aux = (
-        rng=rng,
-        adtype=adtype,
-        obj=obj,
-        problem=ad_prob,
-        restructure=restructure,
-        q_stop=q,
+        rng=rng, adtype=adtype, obj=obj, problem=ad_prob, restructure=restructure, q_stop=q
     )
     obj_ad_prep = AdvancedVI._prepare_gradient(
         estimate_repgradelbo_ad_forward, adtype, params, aux
diff --git a/src/optimize.jl b/src/optimize.jl
index 19bdc525..116ec88f 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -2,7 +2,7 @@
 """
     optimize(
         [rng::Random.AbstractRNG = Random.default_rng(),]
-        algorithm::AbstractAlgorithm,
+        algorithm::AbstractVariationalAlgorithm,
         max_iter::Int,
         prob,
         q_init,
@@ -41,7 +41,7 @@ The content of the `NamedTuple` will be concatenated into the corresponding entr
 """
 function optimize(
     rng::Random.AbstractRNG,
-    algorithm::AbstractAlgorithm,
+    algorithm::AbstractVariationalAlgorithm,
     max_iter::Int,
     prob,
     q_init,
@@ -81,7 +81,12 @@ function optimize(
 end
 
 function optimize(
-    algorithm::AbstractAlgorithm, max_iter::Int, prob, q_init, objargs...; kwargs...
+    algorithm::AbstractVariationalAlgorithm,
+    max_iter::Int,
+    prob,
+    q_init,
+    objargs...;
+    kwargs...,
 )
     return optimize(
         Random.default_rng(), algorithm, max_iter, prob, q_init, objargs...; kwargs...
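
As a quick illustration for reviewers: below is a minimal sketch of what a custom algorithm could look like under the new three-method design. The names `DummyShift` and `DummyShiftState` are hypothetical, invented for this example; only `AbstractVariationalAlgorithm`, `init`, `step`, `output`, and `optimize` come from this diff, and the `(state, info)` return value of `step` is an assumption inferred from the `# Returns` section of its docstring.

```julia
using Random
using AdvancedVI

# Hypothetical toy algorithm: nudges a mean vector by a fixed amount per
# iteration. It exists only to show the shape of the three-method interface.
struct DummyShift <: AdvancedVI.AbstractVariationalAlgorithm
    stepsize::Float64
end

struct DummyShiftState{Q}
    q::Q   # current variational approximation (here just a mean vector)
    t::Int # iteration counter
end

# `init` packs the initial approximation into the algorithm's state.
# This toy algorithm has no use for `rng` or `prob`.
function AdvancedVI.init(::Random.AbstractRNG, ::DummyShift, q_init, prob)
    return DummyShiftState(q_init, 0)
end

# `step` performs one update and (assuming the convention suggested by the
# docstring above) returns the new state together with an `info` NamedTuple.
function AdvancedVI.step(
    ::Random.AbstractRNG,
    alg::DummyShift,
    state::DummyShiftState,
    callback,
    objargs...;
    kwargs...,
)
    q_new = state.q .+ alg.stepsize
    state_new = DummyShiftState(q_new, state.t + 1)
    return state_new, (iteration=state_new.t,)
end

# `output` extracts the final approximation from the last state.
AdvancedVI.output(::DummyShift, state::DummyShiftState) = state.q
```

If the `(state, info)` assumption holds, a call like `optimize(DummyShift(0.01), 100, nothing, zeros(2))` (the toy algorithm ignores `prob`, so `nothing` is passed) would drive the `init`/`step`/`output` cycle without touching any of the SGD-in-parameter-space machinery, which is exactly the flexibility this change is after.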