@@ -131,42 +131,44 @@ struct LogDensityFunction{
131131 adtype:: AD
132132 " (internal use only) gradient preparation object for the model"
133133 prep:: Union{Nothing,DI.GradientPrep}
134+ end
134135
135- function LogDensityFunction (
136- model:: Model ,
137- getlogdensity:: Function = getlogjoint_internal,
138- varinfo:: AbstractVarInfo = ldf_default_varinfo (model, getlogdensity);
139- adtype:: Union{ADTypes.AbstractADType,Nothing} = nothing ,
140- )
141- if adtype === nothing
142- prep = nothing
136+ function LogDensityFunction (
137+ model:: Model ,
138+ getlogdensity:: Function = getlogjoint_internal,
139+ varinfo:: AbstractVarInfo = ldf_default_varinfo (model, getlogdensity);
140+ adtype:: Union{ADTypes.AbstractADType,Nothing} = nothing ,
141+ )
142+ if adtype === nothing
143+ prep = nothing
144+ else
145+ # Make backend-specific tweaks to the adtype
146+ adtype = tweak_adtype (adtype, model, varinfo)
147+ # Check whether it is supported
148+ is_supported (adtype) ||
149+ @warn " The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
150+ # Get a set of dummy params to use for prep
151+ x = [val for val in varinfo[:]]
152+ if use_closure (adtype)
153+ prep = DI. prepare_gradient (
154+ LogDensityAt (model, getlogdensity, varinfo), adtype, x
155+ )
143156 else
144- # Make backend-specific tweaks to the adtype
145- adtype = tweak_adtype (adtype, model, varinfo)
146- # Check whether it is supported
147- is_supported (adtype) ||
148- @warn " The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
149- # Get a set of dummy params to use for prep
150- x = [val for val in varinfo[:]]
151- if use_closure (adtype)
152- prep = DI. prepare_gradient (
153- LogDensityAt (model, getlogdensity, varinfo), adtype, x
154- )
155- else
156- prep = DI. prepare_gradient (
157- logdensity_at,
158- adtype,
159- x,
160- DI. Constant (model),
161- DI. Constant (getlogdensity),
162- DI. Constant (varinfo),
163- )
164- end
157+ prep = DI. prepare_gradient (
158+ logdensity_at,
159+ adtype,
160+ x,
161+ DI. Constant (model),
162+ DI. Constant (getlogdensity),
163+ DI. Constant (varinfo),
164+ )
165165 end
166- return new {typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)} (
167- model, getlogdensity, varinfo, adtype, prep
168- )
169166 end
167+ return LogDensityFunction{
168+ typeof (model),typeof (getlogdensity),typeof (varinfo),typeof (adtype)
169+ }(
170+ model, getlogdensity, varinfo, adtype, prep
171+ )
170172end
171173
172174"""
0 commit comments