diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 7c7438c9f..84305feb2 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -131,42 +131,44 @@ struct LogDensityFunction{ adtype::AD "(internal use only) gradient preparation object for the model" prep::Union{Nothing,DI.GradientPrep} +end - function LogDensityFunction( - model::Model, - getlogdensity::Function=getlogjoint_internal, - varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - ) - if adtype === nothing - prep = nothing +function LogDensityFunction( + model::Model, + getlogdensity::Function=getlogjoint_internal, + varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, +) + if adtype === nothing + prep = nothing + else + # Make backend-specific tweaks to the adtype + adtype = tweak_adtype(adtype, model, varinfo) + # Check whether it is supported + is_supported(adtype) || + @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." + # Get a set of dummy params to use for prep + x = [val for val in varinfo[:]] + if use_closure(adtype) + prep = DI.prepare_gradient( + LogDensityAt(model, getlogdensity, varinfo), adtype, x + ) else - # Make backend-specific tweaks to the adtype - adtype = tweak_adtype(adtype, model, varinfo) - # Check whether it is supported - is_supported(adtype) || - @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." - # Get a set of dummy params to use for prep - x = [val for val in varinfo[:]] - if use_closure(adtype) - prep = DI.prepare_gradient( - LogDensityAt(model, getlogdensity, varinfo), adtype, x - ) - else - prep = DI.prepare_gradient( - logdensity_at, - adtype, - x, - DI.Constant(model), - DI.Constant(getlogdensity), - DI.Constant(varinfo), - ) - end + prep = DI.prepare_gradient( + logdensity_at, + adtype, + x, + DI.Constant(model), + DI.Constant(getlogdensity), + DI.Constant(varinfo), + ) end - return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)}( - model, getlogdensity, varinfo, adtype, prep - ) end + return LogDensityFunction{ + typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype) + }( + model, getlogdensity, varinfo, adtype, prep + ) end """