Skip to content

Commit 175c17c

Browse files
authored
Use invokelatest in tests (because of Julia 1.12) (#404)
Julia 1.12 makes some changes to how world age is handled. JuliaBUGS use `eval` to generate node functions and this is problematic under 1.12. If a user uses JuliaBUGS `compile` in REPL, there is no problem, like before. This issue is particularly pronounced when using JuliaBUGS with a pattern like `compile -> LogDensityProblems.logdensity` within the same script. It can be pretty involved fixing the problem fundamentally. For now, this PR patches some of the tests to use `invokelatest` so they run. Other than it's nontrivial to fix the problem, we are likely seeing a change to the AD interface for JuliaBUGS following #397.
1 parent d36cbd3 commit 175c17c

File tree

13 files changed

+238
-139
lines changed

13 files changed

+238
-139
lines changed

.github/workflows/Benchmark.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
- name: Set up Julia
4141
uses: julia-actions/setup-julia@v2
4242
with:
43-
version: '1'
43+
version: '1.11'
4444
arch: ${{ matrix.arch }}
4545

4646
- uses: actions/cache@v4

.github/workflows/Coverage.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
strategy:
2121
matrix:
2222
version:
23-
- '1'
23+
- '1.11'
2424
os:
2525
- ubuntu-latest
2626
arch:

.github/workflows/Tests.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ jobs:
2323
fail-fast: false
2424
matrix:
2525
version:
26+
- '1.11'
2627
- '1'
2728
# - 'min'
2829
# - 'pre'
@@ -78,6 +79,7 @@ jobs:
7879
fail-fast: false
7980
matrix:
8081
version:
82+
- '1.11'
8183
- '1'
8284
os:
8385
- ubuntu-latest
@@ -127,6 +129,7 @@ jobs:
127129
fail-fast: false
128130
matrix:
129131
version:
132+
- '1.11'
130133
- '1'
131134
os:
132135
- ubuntu-latest

.github/workflows/TestsMacOS.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ jobs:
2222
strategy:
2323
fail-fast: false
2424
matrix:
25-
version:
25+
version:
26+
- '1.11'
2627
- '1'
2728
# - 'min'
2829
# - 'pre'
@@ -75,6 +76,7 @@ jobs:
7576
fail-fast: false
7677
matrix:
7778
version:
79+
- '1.11'
7880
- '1'
7981
os: [macOS-latest]
8082
arch: [aarch64]
@@ -122,6 +124,7 @@ jobs:
122124
fail-fast: false
123125
matrix:
124126
version:
127+
- '1.11'
125128
- '1'
126129
os: [macOS-latest]
127130
arch: [aarch64]

JuliaBUGS/benchmark/benchmark.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,9 @@ function _create_results_dataframe(results::OrderedDict{Symbol,BenchmarkResult})
8787
end
8888

8989
function _print_results_table(
90-
results::OrderedDict{Symbol,BenchmarkResult}; backend=Val(:text)
90+
results::OrderedDict{Symbol,BenchmarkResult}; backend::Symbol=:text
9191
)
9292
df = _create_results_dataframe(results)
93-
return pretty_table(
94-
df;
95-
header=["Model", "Parameters", "Density Time (µs)", "Density+Gradient Time (µs)"],
96-
backend=backend,
97-
)
93+
rename!(df, ["Model", "Parameters", "Density Time (µs)", "Density+Gradient Time (µs)"])
94+
return pretty_table(df; backend=backend)
9895
end

JuliaBUGS/benchmark/run_benchmarks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ for (model_name, model) in zip(examples_to_benchmark, juliabugs_models)
4545
end
4646

4747
println("### Stan results:")
48-
_print_results_table(stan_results; backend=Val(:markdown))
48+
_print_results_table(stan_results; backend=:markdown)
4949

5050
println("### JuliaBUGS Mooncake results:")
51-
_print_results_table(juliabugs_results; backend=Val(:markdown))
51+
_print_results_table(juliabugs_results; backend=:markdown)

JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
end
88
data = (mu=[0, 0], sigma=[1 0; 0 1])
99
model = compile(model_def, data)
10-
ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))
10+
ad_model = Base.invokelatest(ADgradient, :ReverseDiff, model; compile=Val(true))
1111
n_samples, n_adapts = 10, 0
1212
D = LogDensityProblems.dimension(model)
1313
initial_θ = rand(D)
14-
samples_and_stats = AbstractMCMC.sample(
14+
samples_and_stats = Base.invokelatest(
15+
AbstractMCMC.sample,
1516
StableRNG(1234),
1617
ad_model,
1718
NUTS(0.8),
@@ -33,14 +34,15 @@
3334
JuliaBUGS.BUGSExamples, example
3435
)
3536
model = JuliaBUGS.compile(model_def, data, inits)
36-
ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))
37+
ad_model = Base.invokelatest(ADgradient, :ReverseDiff, model; compile=Val(true))
3738

3839
n_samples, n_adapts = 1000, 1000
3940

4041
D = LogDensityProblems.dimension(model)
41-
initial_θ = JuliaBUGS.getparams(model)
42+
initial_θ = Base.invokelatest(JuliaBUGS.getparams, model)
4243

43-
samples_and_stats = AbstractMCMC.sample(
44+
samples_and_stats = Base.invokelatest(
45+
AbstractMCMC.sample,
4446
StableRNG(1234),
4547
ad_model,
4648
NUTS(0.8),

JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,14 @@
2727
)
2828

2929
model = compile(model_def, data, (;))
30-
ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))
30+
ad_model = Base.invokelatest(ADgradient, :ReverseDiff, model; compile=Val(true))
3131
n_samples, n_adapts = 2000, 1000
3232

3333
D = LogDensityProblems.dimension(model)
3434
initial_θ = rand(D)
3535

36-
hmc_chain = AbstractMCMC.sample(
36+
hmc_chain = Base.invokelatest(
37+
AbstractMCMC.sample,
3738
ad_model,
3839
NUTS(0.8),
3940
n_samples;
@@ -72,7 +73,8 @@
7273

7374
n_samples, n_adapts = 20000, 5000
7475

75-
mh_chain = AbstractMCMC.sample(
76+
mh_chain = Base.invokelatest(
77+
AbstractMCMC.sample,
7678
model,
7779
RWMH(MvNormal(zeros(D), I)),
7880
n_samples;
@@ -108,9 +110,9 @@
108110
sigma[3] ~ InverseGamma(2, 3)
109111
end
110112
model = compile(model_def, (;))
111-
ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))
112-
hmc_chain = AbstractMCMC.sample(
113-
ad_model, NUTS(0.8), 10; progress=false, chain_type=Chains
113+
ad_model = Base.invokelatest(ADgradient, :ReverseDiff, model; compile=Val(true))
114+
hmc_chain = Base.invokelatest(
115+
AbstractMCMC.sample, ad_model, NUTS(0.8), 10; progress=false, chain_type=Chains
114116
)
115117
@test Set(hmc_chain.name_map[:parameters]) == Set([
116118
Symbol("sigma[3]"),

JuliaBUGS/test/gibbs.jl

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,16 @@ using StatsBase: mode
165165
model = compile(model_def, data, (;))
166166

167167
# single step
168-
p_s, st_init = AbstractMCMC.step(
168+
p_s, st_init = Base.invokelatest(
169+
AbstractMCMC.step,
169170
Random.default_rng(),
170171
AbstractMCMC.LogDensityModel(model),
171172
Gibbs(model, IndependentMH()),
172173
)
173174

174175
# following step
175-
p_s, st = AbstractMCMC.step(
176+
p_s, st = Base.invokelatest(
177+
AbstractMCMC.step,
176178
Random.default_rng(),
177179
AbstractMCMC.LogDensityModel(model),
178180
Gibbs(model, IndependentMH()),
@@ -230,7 +232,9 @@ using StatsBase: mode
230232

231233
# Test that sampling runs without error
232234
rng = Random.MersenneTwister(123)
233-
chain = sample(rng, model, gibbs, 100; progress=false, chain_type=Chains)
235+
chain = Base.invokelatest(
236+
sample, rng, model, gibbs, 100; progress=false, chain_type=Chains
237+
)
234238

235239
@test chain isa AbstractMCMC.AbstractChains
236240
@test size(chain, 1) == 100 # Number of samples
@@ -255,7 +259,8 @@ using StatsBase: mode
255259
# especially in a single-parameter model where Gibbs reduces to plain MH
256260
rng = Random.MersenneTwister(42)
257261
gibbs = Gibbs(model, IndependentMH())
258-
chain = sample(
262+
chain = Base.invokelatest(
263+
sample,
259264
rng,
260265
model,
261266
gibbs,
@@ -312,7 +317,8 @@ using StatsBase: mode
312317
# Sample with Gibbs - need more samples for IndependentMH
313318
rng = Random.MersenneTwister(42)
314319
gibbs = Gibbs(model, IndependentMH())
315-
chain = sample(
320+
chain = Base.invokelatest(
321+
sample,
316322
rng,
317323
model,
318324
gibbs,
@@ -350,7 +356,8 @@ using StatsBase: mode
350356
# Sample with Gibbs
351357
rng1 = Random.MersenneTwister(789)
352358
gibbs = Gibbs(model, IndependentMH())
353-
chain_gibbs = sample(
359+
chain_gibbs = Base.invokelatest(
360+
sample,
354361
rng1,
355362
model,
356363
gibbs,
@@ -407,7 +414,8 @@ using StatsBase: mode
407414
gibbs = Gibbs(model, sampler_map)
408415

409416
rng = Random.MersenneTwister(999)
410-
chain = sample(
417+
chain = Base.invokelatest(
418+
sample,
411419
rng,
412420
model,
413421
gibbs,
@@ -448,11 +456,15 @@ using StatsBase: mode
448456

449457
# Take one step
450458
rng = Random.MersenneTwister(123)
451-
env1, state = AbstractMCMC.step(
452-
rng, AbstractMCMC.LogDensityModel(model_init), gibbs
459+
env1, state = Base.invokelatest(
460+
AbstractMCMC.step, rng, AbstractMCMC.LogDensityModel(model_init), gibbs
453461
)
454-
env2, _ = AbstractMCMC.step(
455-
rng, AbstractMCMC.LogDensityModel(model_init), gibbs, state
462+
env2, _ = Base.invokelatest(
463+
AbstractMCMC.step,
464+
rng,
465+
AbstractMCMC.LogDensityModel(model_init),
466+
gibbs,
467+
state,
456468
)
457469

458470
# When updating a, b and c should remain fixed in that sub-step
@@ -541,8 +553,8 @@ using StatsBase: mode
541553
gibbs_invalid = Gibbs(model, sampler_map_invalid)
542554

543555
rng = Random.MersenneTwister(123)
544-
@test_throws ErrorException sample(
545-
rng, model, gibbs_invalid, 10; progress=false, chain_type=Chains
556+
@test_throws ErrorException Base.invokelatest(
557+
sample, rng, model, gibbs_invalid, 10; progress=false, chain_type=Chains
546558
)
547559

548560
# Also test with NUTS
@@ -551,8 +563,8 @@ using StatsBase: mode
551563
@varname(σ) => IndependentMH(),
552564
)
553565
gibbs_nuts = Gibbs(model, sampler_map_nuts)
554-
@test_throws ErrorException sample(
555-
rng, model, gibbs_nuts, 10; progress=false, chain_type=Chains
566+
@test_throws ErrorException Base.invokelatest(
567+
sample, rng, model, gibbs_nuts, 10; progress=false, chain_type=Chains
556568
)
557569
end
558570

@@ -565,7 +577,9 @@ using StatsBase: mode
565577
gibbs2 = Gibbs(model, sampler_map2)
566578

567579
rng = Random.MersenneTwister(456)
568-
chain2 = sample(rng, model, gibbs2, 50; progress=false, chain_type=Chains)
580+
chain2 = Base.invokelatest(
581+
sample, rng, model, gibbs2, 50; progress=false, chain_type=Chains
582+
)
569583

570584
@test chain2 isa AbstractMCMC.AbstractChains
571585
@test size(chain2, 1) == 50
@@ -613,7 +627,8 @@ using StatsBase: mode
613627
model_init = initialize!(model, init_params)
614628

615629
rng = Random.MersenneTwister(789)
616-
chain = sample(
630+
chain = Base.invokelatest(
631+
sample,
617632
rng,
618633
model_init,
619634
gibbs,
@@ -664,7 +679,9 @@ using StatsBase: mode
664679
gibbs = Gibbs(model, sampler_map)
665680

666681
rng = StableRNG(1234)
667-
chain = sample(rng, model, gibbs, 1000; progress=false, chain_type=Chains)
682+
chain = Base.invokelatest(
683+
sample, rng, model, gibbs, 1000; progress=false, chain_type=Chains
684+
)
668685

669686
@test chain isa AbstractMCMC.AbstractChains
670687
@test size(chain, 1) == 1000
@@ -699,7 +716,9 @@ using StatsBase: mode
699716
gibbs = Gibbs(model, sampler_map)
700717

701718
rng = Random.MersenneTwister(123)
702-
chain = sample(rng, model, gibbs, 200; progress=false, chain_type=Chains)
719+
chain = Base.invokelatest(
720+
sample, rng, model, gibbs, 200; progress=false, chain_type=Chains
721+
)
703722

704723
@test chain isa AbstractMCMC.AbstractChains
705724
@test size(chain, 1) == 200
@@ -767,7 +786,8 @@ using StatsBase: mode
767786
model_init = initialize!(model, init_params)
768787

769788
rng = Random.MersenneTwister(789)
770-
chain = sample(
789+
chain = Base.invokelatest(
790+
sample,
771791
rng,
772792
model_init,
773793
gibbs,
@@ -828,14 +848,18 @@ using StatsBase: mode
828848

829849
# Manually step through to inspect states
830850
logdensitymodel = AbstractMCMC.LogDensityModel(model)
831-
val, state = AbstractMCMC.step(rng, logdensitymodel, gibbs; model=model)
851+
val, state = Base.invokelatest(
852+
AbstractMCMC.step, rng, logdensitymodel, gibbs; model=model
853+
)
832854

833855
# Initial state should have empty sub_states
834856
@test isempty(state.sub_states)
835857

836858
# Step a few times
837859
for i in 1:3
838-
val, state = AbstractMCMC.step(rng, logdensitymodel, gibbs, state; model=model)
860+
val, state = Base.invokelatest(
861+
AbstractMCMC.step, rng, logdensitymodel, gibbs, state; model=model
862+
)
839863
end
840864

841865
# After stepping, HMC samplers should have preserved states
@@ -846,7 +870,7 @@ using StatsBase: mode
846870
@test !haskey(state.sub_states, [@varname(β)])
847871

848872
# Verify that the sampler still works correctly
849-
chain = sample(rng, model, gibbs, 100; progress=false)
873+
chain = Base.invokelatest(sample, rng, model, gibbs, 100; progress=false)
850874
@test length(chain) == 100
851875
end
852876
end

0 commit comments

Comments
 (0)