Remove the Bijectors extension #219
Merged
34 commits
4bf2fce remove the bijectors extension and relevant tests, benchmarks (Red-Portal)
9dc43bb fix remove use of Bijectors in tests (Red-Portal)
9479ed4 fix typo in HISTORY (Red-Portal)
0142a82 bump AdvancedVI version (Red-Portal)
bb47ced run formatter (Red-Portal)
6fa2bc3 run formatter (Red-Portal)
d6157f1 Merge branch 'remove_bijectors' of github.com:TuringLang/AdvancedVI.j… (Red-Portal)
e201345 fix removed include to removed file (Red-Portal)
66a22a5 fix remove erroenous references to bijector (Red-Portal)
f5f7fcb fix missing comma (Red-Portal)
7a6e902 fix remove include to removed file (Red-Portal)
2a1755a update READMe (Red-Portal)
f23eb7a fix move benchmark model to main file (Red-Portal)
4d8d95e update docs to the new recommended use of Bijectors (Red-Portal)
fba3d9f run formatter (Red-Portal)
4daaa79 run formatter (Red-Portal)
a0f13d5 run formatter (Red-Portal)
9a64db7 run formatter (Red-Portal)
048a310 add constraint tutorial (Red-Portal)
69ae57a fix missing import in docs (Red-Portal)
af4ad18 run furmatter to constrained (Red-Portal)
0386029 Merge branch 'remove_bijectors' of github.com:TuringLang/AdvancedVI.j… (Red-Portal)
f9d7f0b fix typo (Red-Portal)
d6dede2 fix bins in normalizing flow tutorial (Red-Portal)
c7e8c44 fix formatting (Red-Portal)
373abdd run formatter (Red-Portal)
6a5c7ba revert changes to documentation (Red-Portal)
c05e2f5 Merge branch 'remove_bijectors' of github.com:TuringLang/AdvancedVI.j… (Red-Portal)
8c657e0 revert changes to the README (Red-Portal)
92f5077 Revert "revert changes to the README" (Red-Portal)
f81398a Revert "revert changes to documentation" (Red-Portal)
6bf17b3 run formatter (Red-Portal)
8c729c5 fix use more sophisticated `TransformedLogDensityProblem` (Red-Portal)
db15461 fix wrong type of markdown env (Red-Portal)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | `@@ -1,7 +1,6 @@` |
| 1 | 1 | `using ADTypes` |
| 2 | 2 | `using AdvancedVI` |
| 3 | 3 | `using BenchmarkTools` |
| 4 | | - `using Bijectors` |
| 5 | 4 | `using Distributions` |
| 6 | 5 | `using DistributionsAD` |
| 7 | 6 | `using Enzyme, ForwardDiff, ReverseDiff, Zygote, Mooncake` |
| | | `@@ -17,8 +16,34 @@ BLAS.set_num_threads(min(4, Threads.nthreads()))` |
| 17 | 16 | `@info sprint(versioninfo)` |
| 18 | 17 | `@info "BLAS threads: $(BLAS.get_num_threads())"` |
| 19 | 18 | |
| 20 | | - `include("normallognormal.jl")` |
| 21 | | - `include("unconstrdist.jl")` |
| | 19 | + `struct Dist{D<:ContinuousMultivariateDistribution}` |
| | 20 | + `dist::D` |
| | 21 | + `end` |
| | 22 | + |
| | 23 | + `function LogDensityProblems.logdensity(model::Dist, x)` |
| | 24 | + `return logpdf(model.dist, x)` |
| | 25 | + `end` |
| | 26 | + |
| | 27 | + `function LogDensityProblems.logdensity_and_gradient(model::Dist, θ)` |
| | 28 | + `return (` |
| | 29 | + `LogDensityProblems.logdensity(model, θ),` |
| | 30 | + `ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),` |
| | 31 | + `)` |
| | 32 | + `end` |
| | 33 | + |
| | 34 | + `function LogDensityProblems.dimension(model::Dist)` |
| | 35 | + `return length(model.dist)` |
| | 36 | + `end` |
| | 37 | + |
| | 38 | + `function LogDensityProblems.capabilities(::Type{<:Dist})` |
| | 39 | + `return LogDensityProblems.LogDensityOrder{0}()` |
| | 40 | + `end` |
| | 41 | + |
| | 42 | + `function normal(; n_dims=10, realtype=Float64)` |
| | 43 | + `μ = fill(realtype(5), n_dims)` |
| | 44 | + `Σ = Diagonal(ones(realtype, n_dims))` |
| | 45 | + `return Dist(MvNormal(μ, Σ))` |
| | 46 | + `end` |
| 22 | 47 | |
| 23 | 48 | `const SUITES = BenchmarkGroup()` |
| 24 | 49 | |
| | | `@@ -33,10 +58,7 @@ end` |
| 33 | 58 | `begin` |
| 34 | 59 | `T = Float64` |
| 35 | 60 | |
| 36 | | - `for (probname, prob) in [` |
| 37 | | - `("normal + bijector", normallognormal(; n_dims=10, realtype=T))` |
| 38 | | - `("normal", normal(; n_dims=10, realtype=T))` |
| 39 | | - `]` |
| | 61 | + `for (probname, prob) in [("normal", normal(; n_dims=10, realtype=T))]` |
| 40 | 62 | `max_iter = 10^4` |
| 41 | 63 | `d = LogDensityProblems.dimension(prob)` |
| 42 | 64 | `opt = Optimisers.Adam(T(1e-3))` |
| | | `@@ -59,9 +81,7 @@ begin` |
| 59 | 81 | `),` |
| 60 | 82 | `]` |
| 61 | 83 | |
| 62 | | - `b = Bijectors.bijector(prob)` |
| 63 | | - `binv = inverse(b)` |
| 64 | | - `q = Bijectors.TransformedDistribution(family, binv)` |
| | 84 | + `q = family` |
| 65 | 85 | `alg = KLMinRepGradDescent(adtype; optimizer=opt, entropy, operator=ClipScale())` |
| 66 | 86 | |
| 67 | 87 | `SUITES[probname][objname][familyname][adname] = begin` |

Inline comment (Member, Author) on `struct Dist{D<:ContinuousMultivariateDistribution}`: The content of …
Two other files shown in this diff were deleted.
Since this code is always the same, it seems like it might be worth putting in a library somewhere... although AdvancedVI isn't the right place for it... maybe Bijectors?
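The repeated code this comment refers to is a wrapper that evaluates a constrained log density in unconstrained coordinates and adds the log-Jacobian of the inverse transform. A minimal, dependency-free sketch of that idea follows. The names `TransformedTarget`, `exp_with_logjac`, and `exponential_logpdf` are hypothetical and not part of AdvancedVI or Bijectors, and the `TransformedLogDensityProblem` used in the PR may be structured differently.

```julia
# Hypothetical wrapper: `logdensity` is the log density on the constrained
# space; `inverse_with_logjac` maps an unconstrained θ to (x, log|det J|).
struct TransformedTarget{F,B}
    logdensity::F
    inverse_with_logjac::B
end

# Log density in unconstrained coordinates (change of variables).
function (t::TransformedTarget)(θ)
    x, logjac = t.inverse_with_logjac(θ)
    return t.logdensity(x) + logjac
end

# Elementwise exp maps ℝⁿ to the positive orthant; its log-Jacobian is sum(θ).
exp_with_logjac(θ) = (exp.(θ), sum(θ))

# Example target: independent Exponential(1) components,
# log p(x) = -sum(x) for x > 0 and -Inf otherwise.
exponential_logpdf(x) = all(>(0), x) ? -sum(x) : -Inf

target = TransformedTarget(exponential_logpdf, exp_with_logjac)
value = target([0.0, 1.0])
```

Because the wrapper only needs a log-density callable and an inverse map that also returns its log-Jacobian, the same struct can be reused for any model with a fixed transform, which is the sense in which the code "is always the same".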
BTW, you might already be aware, but there is an infamously pathological DynamicPPL edge case with something like this:
(of course you don't need DynamicPPL to make a log-density function like that, you could construct it with plain `logpdf`). The issue is that the bijector cannot be statically constructed, because the bijector for `y` will depend on the value of `x`. It's still a one-to-one function since the inputs fully determine the bijector, which fully determines the output. But I think it might not fit nicely into the structure above. Probably you have to do the hardcoded logjac calculation way.
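The model snippet from this comment is not reproduced here. As a hypothetical stand-in, assume `x ~ Normal(0, 1)` and that `y` given `x` is a unit exponential shifted to start at `x`, so the support of `y` is `(x, ∞)` and the transform for `y` depends on the value of `x`. The sketch below hardcodes the log-Jacobian for that assumed model; `logdensity_unconstrained` is an illustrative name, not an API from any package.

```julia
# θ = (x, z) lives in ℝ²; y = x + exp(z) always satisfies y > x.
function logdensity_unconstrained(θ)
    x, z = θ
    y = x + exp(z)                       # inverse transform for y depends on x
    logp_x = -0.5 * x^2 - 0.5 * log(2π)  # Normal(0, 1) log density of x
    logp_y = -(y - x)                    # shifted Exponential(1) log density of y
    logjac = z                           # log|dy/dz| = log(exp(z)) = z
    return logp_x + logp_y + logjac
end

lp = logdensity_unconstrained([0.3, -1.2])
```

The map from `(x, z)` to `(x, y)` has a triangular Jacobian, so its determinant reduces to `dy/dz`, even though the transform applied to `z` cannot be written down without knowing `x`.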
I think `TransformedLogDensities.jl` was supposed to serve this purpose, but not sure why it never decided to work with `Bijectors`.
How are the MCMC algorithms handling this?
Turing doesn't have a single bijector per model, instead there's one bijector per variable, which is constructed at runtime depending on the distribution on the rhs of tilde.
I see. Then I guess this is going to be a known issue for AdvancedVI.
I don't think it's an unsolvable problem though. I feel like if anything it should be fairly easy to support (don't hardcode a single bijector but rather use LogDensityFunction with linking).
(Not in this PR of course)
That's what I am planning to do (in fact already implemented in the open PR downstream) for the Turing interface. But the problem is that the output `q` needs to be wrapped with a `Bijectors.TransformedDistribution`. I don't see any obvious way around this.
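Why the fitted approximation needs wrapping can be seen in a small dependency-free sketch: when optimization runs in unconstrained space, samples and densities of `q` have to be pushed through the inverse transform before they live on the model's original support. The `PushforwardLogNormal` type and its `sample` and `logdensity` functions are hypothetical. For a scalar `exp` transform they only mimic the role a `Bijectors.TransformedDistribution` plays, and do not reproduce its API.

```julia
using Random

# Hypothetical pushforward of Normal(μ, σ) through exp (a log-normal).
struct PushforwardLogNormal
    μ::Float64
    σ::Float64
end

# Sample in unconstrained space, then map to the positive reals.
sample(rng::AbstractRNG, q::PushforwardLogNormal) = exp(q.μ + q.σ * randn(rng))

# Density on the constrained space: base density at z = log(x),
# minus log|dx/dz| = log(exp(z)) = z.
function logdensity(q::PushforwardLogNormal, x)
    z = log(x)
    base = -0.5 * ((z - q.μ) / q.σ)^2 - log(q.σ) - 0.5 * log(2π)
    return base - z
end

q = PushforwardLogNormal(0.0, 1.0)
x = sample(MersenneTwister(1), q)
lp = logdensity(q, 1.0)
```

Without such a wrapper, a caller would receive a `q` whose samples sit in unconstrained space, which is the difficulty the comment describes for the per-variable, runtime-constructed transforms.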