Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Setfield = "0.7, 0.8, 1"
Statistics = "1.6"
StatsBase = "0.31, 0.32, 0.33, 0.34"
StatsFuns = "0.8, 0.9, 1"
julia = "1.10.2"
julia = "1.10"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 0 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@ LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Expand Down
9 changes: 2 additions & 7 deletions test/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,5 @@ function LogDensityProblems.capabilities(::Type{typeof(ℓπ_gdemo)})
return LogDensityProblems.LogDensityOrder{0}()
end

test_show(x) = test_show(s -> length(s) > 0, x)
function test_show(pred, x)
io = IOBuffer(; append=true)
show(io, x)
s = read(io, String)
@test pred(s)
end
test_show(x) = test_show(!isempty, x)
test_show(pred, x) = @test pred(repr(x))
26 changes: 10 additions & 16 deletions test/sampler-vec.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ReTest, AdvancedHMC, LinearAlgebra, UnicodePlots, Random
using ReTest, AdvancedHMC, LinearAlgebra, Random
using Statistics: mean, var, cov

@testset "sample (vectorized)" begin
Expand Down Expand Up @@ -95,7 +95,7 @@ using Statistics: mean, var, cov
end

# Time for multiple runs of single chain
time_seperate = Vector{Float64}(undef, n_chains_max)
time_separate = Vector{Float64}(undef, n_chains_max)

for (i, n_chains) in enumerate(n_chains_list)
t = @elapsed for j in 1:n_chains
Expand All @@ -104,22 +104,16 @@ using Statistics: mean, var, cov
h, κ, θ_init_list[i][:, j], n_samples; verbose=false
)
end
time_seperate[i] = t
time_separate[i] = t
end

# Make plot
fig = lineplot(
collect(1:n_chains_max),
time_mat;
title="Scalabiliry of multiple chains",
name="vectorization",
xlabel="Num of chains",
ylabel="Time (s)",
println("\nVectorized vs separate sampling")
println(" number of chains: ", n_chains_list)
println(" elapsed time [s] (vectorized): ", round.(time_mat; sigdigits=2))
println(" elapsed time [s] (separate): ", round.(time_separate; sigdigits=2))
println(
" ratio of elapsed time: ",
round.(time_separate ./ time_mat; sigdigits=2),
)
lineplot!(fig, collect(n_chains_list), time_seperate; color=:blue, name="seperate")
println()
show(fig)
println()
println()
end
end
3 changes: 1 addition & 2 deletions test/sampler.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# Allow pass --progress when running this script individually to turn on progress meter
const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false

using ReTest, AdvancedHMC, LinearAlgebra, Random, Plots
using ReTest, AdvancedHMC, LinearAlgebra, Random
using AdvancedHMC: StaticTerminationCriterion, DynamicTerminationCriterion
using Setfield
using Statistics: mean, var, cov
unicodeplots()

function test_stats(
::Trajectory{TS,I,TC}, stats, n_adapts
Expand Down
Loading