Open
Description
When sampling using the AbstractMCMC interface (as documented at https://turinglang.org/AdvancedHMC.jl/stable/get_started/#Using-the-AbstractMCMC-Interface), the Symbol values of progress (e.g. :perchain) described at https://turinglang.org/AbstractMCMC.jl/stable/api/#Progress-logging are not supported. Here's an MWE:
using AdvancedHMC, ForwardDiff, LogDensityProblems, LinearAlgebra, AbstractMCMC, LogDensityProblemsAD
struct LogTargetDensity
    dim::Int
end
LogDensityProblems.logdensity(p::LogTargetDensity, θ) = -sum(abs2, θ) / 2 # standard multivariate normal
LogDensityProblems.dimension(p::LogTargetDensity) = p.dim
function LogDensityProblems.capabilities(::Type{LogTargetDensity})
    return LogDensityProblems.LogDensityOrder{0}()
end
ℓπ = LogTargetDensity(10)
model = AbstractMCMC.LogDensityModel(LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ))
n_samples, n_adapts = 2_000, 1_000
sampler = NUTS(0.8)
# sample with progress=true
AbstractMCMC.sample(model, sampler, n_adapts + n_samples; n_adapts, progress=true) # works
# sample with progress=:perchain
AbstractMCMC.sample(model, sampler, n_adapts + n_samples; n_adapts, progress=:perchain) # errors

ERROR: TypeError: non-boolean (Symbol) used in boolean context
Stacktrace:
 [1] sample(model_or_logdensity::AbstractMCMC.LogDensityModel{…}, sampler::NUTS{…}, N_or_isdone::Int64; kwargs::@Kwargs{…})
   @ AbstractMCMC ~/.julia/packages/AbstractMCMC/mcqES/src/sample.jl:23
 [2] top-level scope
   @ REPL[38]:2
Some type information was truncated. Use `show(err)` to see complete types.

Presumably this happens because lines like this constrain the type of progress to be a Bool:
AdvancedHMC.jl/src/abstractmcmc.jl, line 59 in dc8dc1c:
    progress::Bool=true,
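For reference, here is a toy comparison of two ways a Symbol can be rejected where a Bool is expected. Both functions are hypothetical stand-ins (not the real AdvancedHMC or AbstractMCMC methods): one annotates the keyword as Bool, the other accepts anything and only fails once the value reaches a condition. Comparing the printed messages with the one in the trace above may help narrow down where the Symbol is actually rejected.

```julia
# Hypothetical stand-ins, not the real AdvancedHMC or AbstractMCMC methods.
typed_kw(; progress::Bool=true) = progress                  # keyword annotated as Bool
untyped_kw(; progress=true) = progress ? "bar" : "no bar"   # value is used as a condition

# Run `f` and return the exception it throws, or `nothing` if it succeeds.
capture(f) = try
    f()
    nothing
catch err
    err
end

err_typed = capture(() -> typed_kw(progress=:perchain))
err_untyped = capture(() -> untyped_kw(progress=:perchain))

# Print both error messages so they can be compared with the reported one.
println(sprint(showerror, err_typed))
println(sprint(showerror, err_untyped))
```

Both calls fail with a TypeError, so the exception type alone does not say which of the two situations occurred; the message text does.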
However, even if that weren't the cause, the following lines set progress to false whenever the user doesn't provide a callback (no clue why this is the case):
AdvancedHMC.jl/src/abstractmcmc.jl, lines 72 to 75 in dc8dc1c:
    if callback === nothing
        callback = HMCProgressCallback(N; progress=progress, verbose=verbose)
        progress = false # don't use AMCMC's progress-funtionality
    end
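One possible direction, sketched under assumptions: widen the keyword to accept a Symbol and forward Symbol values to AbstractMCMC untouched, while keeping the current Bool behaviour. The helper name `split_progress` and its return fields are hypothetical and exist only for illustration; this is not the package's code and may not match how the maintainers would want to fix it.

```julia
# Hypothetical helper, not part of AdvancedHMC or AbstractMCMC.
# Decides which side reports progress:
#   hmc_progress   -> what would be passed to HMCProgressCallback
#   amcmc_progress -> what would be forwarded to AbstractMCMC's `progress` keyword
function split_progress(progress::Union{Bool,Symbol}, has_callback::Bool)
    if progress isa Symbol
        # e.g. :perchain; leave interpretation of the Symbol to AbstractMCMC
        return (hmc_progress=false, amcmc_progress=progress)
    elseif has_callback
        # a user-supplied callback leaves AbstractMCMC's progress setting as given
        return (hmc_progress=false, amcmc_progress=progress)
    else
        # mirrors the quoted lines: the HMC callback reports progress,
        # and AbstractMCMC's own progress logging is switched off
        return (hmc_progress=progress, amcmc_progress=false)
    end
end

println(split_progress(true, false))
println(split_progress(:perchain, false))
```

With this split, `progress=true` behaves as it does today, and a Symbol is never used in a boolean context on the AdvancedHMC side.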