-
Notifications
You must be signed in to change notification settings - Fork 19
update the example in the README #193
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
|
||
| ```julia | ||
| using LogDensityProblems | ||
| import LogDensityProblems |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| import LogDensityProblems | |
| using LogDensityProblems: LogDensityProblems |
| ``` | ||
|
|
||
| Since the support of `x` is constrained to be positive and VI is best done in the unconstrained Euclidean space, we need to use a *bijector* to transform `x` into unconstrained Euclidean space. We will use the [`Bijectors.jl`](https://github.com/TuringLang/Bijectors.jl) package for this purpose. | ||
| Since the support of `σ` is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a *bijector* to transform `θ`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| Since the support of `σ` is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a *bijector* to transform `θ`. | |
| Since the support of `σ` is constrained to be positive and most VI algorithms assume an unconstrained Euclidean support, we need to use a *bijector* to transform `θ`. |
|
|
||
| ```julia | ||
| using Bijectors | ||
| import Bijectors |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| import Bijectors | |
| using Bijectors: Bijectors |
| import OpenML | ||
| import DataFrames |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| import OpenML | |
| import DataFrames | |
| using OpenML: OpenML | |
| using DataFrames: DataFrames |
| X = Matrix{Float64}(data[:, 1:(end - 1)]) | ||
| y = Vector{Bool}(data[:, end] .== "Mine") | ||
| ``` | ||
| Let's apply some basic pre-processing and add an intercept column: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| Let's apply some basic pre-processing and add an intercept column: | |
| Let's apply some basic pre-processing and add an intercept column: |
| d = LogDensityProblems.dimension(model_ad) | ||
| q = MeanFieldGaussian(zeros(d), Diagonal(ones(d))) | ||
| ``` | ||
| The bijector can now be applied to `q` to match the support of the target problem. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| The bijector can now be applied to `q` to match the support of the target problem. | |
| The bijector can now be applied to `q` to match the support of the target problem. |
| q = MeanFieldGaussian(zeros(d), Diagonal(ones(d))) | ||
| ``` | ||
| The bijector can now be applied to `q` to match the support of the target problem. | ||
| ```julia |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| ```julia | |
| ```julia |
|
|
||
| # Run inference | ||
| ``` | ||
| We can now run VI: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| We can now run VI: | |
| We can now run VI: |
| # Run inference | ||
| ``` | ||
| We can now run VI: | ||
| ```julia |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| ```julia | |
| ```julia |
| q_avg, info, _ = AdvancedVI.optimize( | ||
| alg, | ||
| max_iter, | ||
| model, | ||
| model_ad, | ||
| q_transformed; | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| q_avg, info, _ = AdvancedVI.optimize( | |
| alg, | |
| max_iter, | |
| model, | |
| model_ad, | |
| q_transformed; | |
| ) | |
| q_avg, info, _ = AdvancedVI.optimize(alg, max_iter, model_ad, q_transformed;) |
|
AdvancedVI.jl documentation for PR #193 is available at: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Benchmark Results
| Benchmark suite | Current: 4711e40 | Previous: 864f5e9 | Ratio |
|---|---|---|---|
normal/RepGradELBO + STL/meanfield/Zygote |
3884113557 ns |
3940463414.5 ns |
0.99 |
normal/RepGradELBO + STL/meanfield/ReverseDiff |
1087580354 ns |
1080457534 ns |
1.01 |
normal/RepGradELBO + STL/meanfield/Mooncake |
1187372637 ns |
1186415910 ns |
1.00 |
normal/RepGradELBO + STL/fullrank/Zygote |
3876169524.5 ns |
3880348067.5 ns |
1.00 |
normal/RepGradELBO + STL/fullrank/ReverseDiff |
1586673546 ns |
1578889137 ns |
1.00 |
normal/RepGradELBO + STL/fullrank/Mooncake |
1230627189 ns |
1245752131 ns |
0.99 |
normal/RepGradELBO/meanfield/Zygote |
2729318274.5 ns |
2745244233 ns |
0.99 |
normal/RepGradELBO/meanfield/ReverseDiff |
753836095 ns |
779172740 ns |
0.97 |
normal/RepGradELBO/meanfield/Mooncake |
1053213407 ns |
1065162127 ns |
0.99 |
normal/RepGradELBO/fullrank/Zygote |
2746187641.5 ns |
2779479552 ns |
0.99 |
normal/RepGradELBO/fullrank/ReverseDiff |
938784644 ns |
922764899.5 ns |
1.02 |
normal/RepGradELBO/fullrank/Mooncake |
1101792362 ns |
1099112093 ns |
1.00 |
normal + bijector/RepGradELBO + STL/meanfield/Zygote |
5608935852 ns |
5520186124 ns |
1.02 |
normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff |
2347525187 ns |
2315766289 ns |
1.01 |
normal + bijector/RepGradELBO + STL/meanfield/Mooncake |
4004306569 ns |
3960343348 ns |
1.01 |
normal + bijector/RepGradELBO + STL/fullrank/Zygote |
5631136713 ns |
5464591378 ns |
1.03 |
normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff |
2965503150.5 ns |
2899203076.5 ns |
1.02 |
normal + bijector/RepGradELBO + STL/fullrank/Mooncake |
4171209341.5 ns |
4077037587 ns |
1.02 |
normal + bijector/RepGradELBO/meanfield/Zygote |
4287497307.5 ns |
4259257598 ns |
1.01 |
normal + bijector/RepGradELBO/meanfield/ReverseDiff |
2016252830 ns |
1969823950 ns |
1.02 |
normal + bijector/RepGradELBO/meanfield/Mooncake |
3859733151.5 ns |
3921065959 ns |
0.98 |
normal + bijector/RepGradELBO/fullrank/Zygote |
4364530532 ns |
4347716606 ns |
1.00 |
normal + bijector/RepGradELBO/fullrank/ReverseDiff |
2290625728 ns |
2217527377 ns |
1.03 |
normal + bijector/RepGradELBO/fullrank/Mooncake |
4015809201.5 ns |
3950691343.5 ns |
1.02 |
This comment was automatically generated by workflow using github-action-benchmark.
No description provided.