Conversation

@Red-Portal Red-Portal commented Nov 21, 2025

Consider the case where we would like to approximate a constrained target distribution with density $\pi : \mathcal{X} \to \mathbb{R}_{>0}$ using an unconstrained variational approximation with density $q : \mathbb{R}^d \to \mathbb{R}_{>0}$. The canonical way to deal with this, popularized by the ADVI paper [1], is to use a bijective transformation (a "bijector," as in Bijectors.jl) $b : \mathbb{R}^d \to \mathcal{X}$ such that $q$ is pushed forward into $q_{b^{-1}}$ with density

$$q_{b^{-1}}(z) = q(b^{-1}(z)) {\lvert \mathrm{J}_{b^{-1}}(z) \rvert}$$
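To make the change of variables concrete, here is a minimal sketch (illustrative, not from the PR) with $b = \exp$, so that $\mathcal{X} = (0, \infty)$ and the pushforward of a standard normal $q$ is the familiar log-normal density:

```julia
# Sketch (not from the PR): the change of variables above, written out for
# b = exp, with q a standard normal on R.
q(eta) = exp(-eta^2 / 2) / sqrt(2pi)   # density of q on R
binv(z) = log(z)                       # b^{-1} : (0, Inf) -> R
jac_binv(z) = 1 / z                    # J_{b^{-1}}(z) = d/dz log(z)

# q_{b^{-1}}(z) = q(b^{-1}(z)) * |J_{b^{-1}}(z)|
q_pushforward(z) = q(binv(z)) * abs(jac_binv(z))

# Matches the standard log-normal density exp(-log(z)^2/2) / (z * sqrt(2pi))
@assert isapprox(q_pushforward(2.0), exp(-log(2.0)^2 / 2) / (2.0 * sqrt(2pi)))
```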

Then AdvancedVI needs to solve the problem

$$q_{b^{-1}}^* = \arg\min_{q \in \mathcal{Q}} \;\; \mathrm{D}(q_{b^{-1}}, \pi) .$$

But notice that the optimization is, in reality, over $q$. Therefore, AdvancedVI often needs access to the underlying $q$. I will refer to this as the "primal" scheme.

Previously, this was done by giving special treatment to q <: Bijectors.TransformedDistribution through the Bijectors extension. In particular, the extension had to specialize a large number of methods that simply unwrap a TransformedDistribution before doing their actual work. This behavior is difficult to document and, therefore, was never fully explained in the documentation. Furthermore, each relevant method had to be specialized in the Bijectors extension, which resulted in multiplicative complexity, especially for unit testing.
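The unwrap-and-recurse pattern can be illustrated with a stripped-down sketch (stand-in types, not the actual extension code):

```julia
# Illustrative sketch (not the actual extension code): under the primal
# scheme, every generic method that needs the underlying q must grow a
# shadow specialization that merely unwraps the transform layer.
struct MeanField                 # stand-in for an unconstrained variational family
    mu::Vector{Float64}
    sigma::Vector{Float64}
end
struct Transformed{Q,B}          # stand-in for Bijectors.TransformedDistribution
    dist::Q
    transform::B
end

nparams(q::MeanField) = 2 * length(q.mu)
nparams(q::Transformed) = nparams(q.dist)   # unwrap-and-recurse boilerplate

# The same shadow method repeats for every generic method, multiplying the
# surface area that has to be maintained and unit tested.
q = Transformed(MeanField(zeros(3), ones(3)), exp)
@assert nparams(q) == 6
```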

This, however, is unnecessary. Instead, there exists an equivalent "dual" problem that operates in unconstrained space by approximating the transformed posterior

$$\pi_b(\eta) = \pi(b(\eta)) \, {\lvert \mathrm{J}_{b}(\eta) \rvert} .$$

That is, we can solve the problem

$$q^* = \arg\min_{q \in \mathcal{Q}} \;\; \mathrm{D}(q, \pi_b)$$

and then post-process the output to retrieve $q_{b^{-1}}^*$.
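A minimal sketch of the dual scheme (illustrative, not from the PR), again with $b = \exp$ and a toy unnormalized target on $(0, \infty)$:

```julia
# Sketch (not from the PR): the dual objective works entirely on R by
# pulling a toy constrained target back through b = exp.
logpi(z) = -z - log(z) / 2          # toy unnormalized log-density on (0, Inf)
b(eta) = exp(eta)
logabsdetjac_b(eta) = eta           # log |J_b(eta)| = log(exp(eta)) = eta

# log pi_b(eta) = log pi(b(eta)) + log |J_b(eta)|, defined on all of R
logpi_b(eta) = logpi(b(eta)) + logabsdetjac_b(eta)

# The optimization over q can now ignore the constraint entirely; the
# optimum q* is mapped back to X afterwards via the pushforward q*_{b^{-1}}.
@assert isfinite(logpi_b(-10.0)) && isfinite(logpi_b(10.0))
```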

Within this context, this PR removes the Bijectors extension. The rationale is as follows:

  • As mentioned above, AdvancedVI doesn't need to implement the primal scheme. In fact, the upcoming interface in Turing plans to implement the dual scheme above.
  • The new algorithms KLMinNaturalGradDescent, KLMinWassFwdBwd, and FisherMinBatchMatch, for example, do not operate on constrained supports at all, so they can only be used via the dual scheme. The way KLMinRepGradDescent and friends implemented the primal scheme is therefore redundant and inconsistent at this point.

Instead, a tutorial has been added to the documentation on how to use VI with constrained supports via the dual scheme.

Footnotes

  1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14), 1–45.


```julia
include("normallognormal.jl")
include("unconstrdist.jl")
struct Dist{D<:ContinuousMultivariateDistribution}
```
@Red-Portal (Member Author)
The content of unconstrdist.jl has been moved here.

@github-actions github-actions bot left a comment

Benchmark Results

| Benchmark suite | Current: db15461 | Previous: f81398a | Ratio |
|---|---|---|---|
| normal/RepGradELBO + STL/meanfield/Zygote | 2460531085 ns | 2648590549.5 ns | 0.93 |
| normal/RepGradELBO + STL/meanfield/ReverseDiff | 593417542 ns | 609698835 ns | 0.97 |
| normal/RepGradELBO + STL/meanfield/Mooncake | 245091874 ns | 250082202 ns | 0.98 |
| normal/RepGradELBO + STL/fullrank/Zygote | 1932656603 ns | 2086072781 ns | 0.93 |
| normal/RepGradELBO + STL/fullrank/ReverseDiff | 1115706204 ns | 1178378740 ns | 0.95 |
| normal/RepGradELBO + STL/fullrank/Mooncake | 659773727.5 ns | 684388235.5 ns | 0.96 |
| normal/RepGradELBO/meanfield/Zygote | 1514658031 ns | 1576637144.5 ns | 0.96 |
| normal/RepGradELBO/meanfield/ReverseDiff | 302198740 ns | 310216571 ns | 0.97 |
| normal/RepGradELBO/meanfield/Mooncake | 171503791.5 ns | 174420549.5 ns | 0.98 |
| normal/RepGradELBO/fullrank/Zygote | 1053510637 ns | 1103228381 ns | 0.95 |
| normal/RepGradELBO/fullrank/ReverseDiff | 590708132 ns | 606753640 ns | 0.97 |
| normal/RepGradELBO/fullrank/Mooncake | 542788814 ns | 569936275 ns | 0.95 |

This comment was automatically generated by workflow using github-action-benchmark.

Red-Portal and others added 10 commits November 22, 2025 11:53
@github-actions

AdvancedVI.jl documentation for PR #219 is available at:
https://TuringLang.github.io/AdvancedVI.jl/previews/PR219/

Red-Portal and others added 2 commits November 22, 2025 15:09
@Red-Portal

The updates to the documentation and README have been omitted for clarity and will be added later once the PR is approved.

@penelopeysm penelopeysm left a comment

This looks very good! But could you fix the docs build?

By the way, you can also hide the `nothing` at the end of the docs blocks if you want, so that it doesn't get rendered (https://documenter.juliadocs.org/stable/man/syntax/#reference-at-example):

```@example myex
...
nothing # hide
```

@penelopeysm

To be clear, I'm not in favour of the docs and README being a separate PR. I get that it's a big change, but I think it's better for PRs to be atomic.

@Red-Portal

Red-Portal commented Nov 25, 2025

@penelopeysm I've added the updates to the README and the docs as requested.

That # hide trick is good to know! Let me punt on that for now, but I'll definitely use it later.

Comment on lines +91 to +95
```julia
struct TransformedLogDensityProblem{Prob,BInv}
    prob::Prob
    binv::BInv
end
```
@penelopeysm
Since this code is always the same, it seems like it might be worth putting in a library somewhere... although AdvancedVI isn't the right place for it... maybe Bijectors?

@penelopeysm penelopeysm Nov 25, 2025

BTW, you might already be aware, but there is an infamously pathological DynamicPPL edge case with something like this:

```julia
@model function f()
    x ~ Normal()
    y ~ truncated(Normal(); lower = x)
end
```

(Of course, you don't need DynamicPPL to make a log density function like that; you could construct it with plain logpdf.) The issue is that the bijector cannot be statically constructed, because the bijector for y depends on the value of x. It's still a one-to-one function, since the inputs fully determine the bijector, which fully determines the output. But I think it might not fit nicely into the structure above. You probably have to compute the log-Jacobian directly instead.
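A minimal sketch of why the bijector must be built at runtime (illustrative, not from the thread):

```julia
# Sketch (not from the thread): the lower bound of y's support is the
# sampled value of x, so the map for y changes with x.  The joint transform
# (eta_x, eta_y) -> (x, y) = (eta_x, x + exp(eta_y)) is still bijective,
# but its log-Jacobian has to be accumulated at runtime.
function constrain(eta)
    x = eta[1]                  # x is unconstrained: identity map
    y = x + exp(eta[2])         # shift-by-x exp map enforces y > x
    logjac = eta[2]             # log |dy/deta_2| = log(exp(eta[2])) = eta[2]
    return (x, y), logjac
end

(xy, lj) = constrain([0.5, -1.0])
@assert xy[2] > xy[1]           # y always lands strictly above its bound x
@assert lj == -1.0
```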

@Red-Portal Red-Portal Nov 25, 2025

> Since this code is always the same, it seems like it might be worth putting in a library somewhere... although AdvancedVI isn't the right place for it... maybe Bijectors?

I think TransformedLogDensities.jl was supposed to serve this purpose, but I'm not sure why it was never made to work with Bijectors.

> BTW, you might already be aware, but there is an infamously pathological DynamicPPL edge case with something like this:

How are the MCMC algorithms handling this?

@penelopeysm

Turing doesn't have a single bijector per model; instead, there's one bijector per variable, which is constructed at runtime depending on the distribution on the right-hand side of the tilde.

@Red-Portal (Member Author)

I see. Then I guess this is going to be a known issue for AdvancedVI.

@penelopeysm

I don't think it's an unsolvable problem, though. If anything, it should be fairly easy to support (don't hardcode a single bijector, but rather use LogDensityFunction with linking).

@penelopeysm

(Not in this PR of course)

@Red-Portal (Member Author)

> (don't hardcode a single bijector but rather use LogDensityFunction with linking).

That's what I am planning to do for the Turing interface (in fact, it is already implemented in the open PR downstream). But the problem is that the output q needs to be wrapped in a Bijectors.TransformedDistribution. I don't see any obvious way around this.

@Red-Portal Red-Portal merged commit 72df3af into main Nov 26, 2025
34 of 40 checks passed
@Red-Portal Red-Portal deleted the remove_bijectors branch November 26, 2025 12:12
Labels

enhancement New feature or request