
Commit ed1db96

fix: make the constructor of LogDensityFunction implicit
1 parent 052bc19 commit ed1db96
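
In Julia, writing any custom inner constructor inside a struct block suppresses the default `new`-based constructor. This commit moves the constructor logic of `LogDensityFunction` out of the struct body into an ordinary outer method, so the inner constructor becomes implicit again and the outer method can finish by calling it with explicit type parameters. A minimal sketch of the pattern, using a toy `Point` type rather than anything from this repository:

# Toy illustration of the inner-vs-outer constructor pattern in this
# commit; `Point` is hypothetical and not part of DynamicPPL.
struct Point{T<:Real}
    x::T
    y::T
    # No inner constructor is declared, so the implicit
    # `Point{T}(x, y)` constructor remains available.
end

# The custom logic lives in an outer constructor, which ends by calling
# the implicit inner constructor with explicit type parameters, just as
# the diff below does for `LogDensityFunction`.
function Point(x::Real, y::Real)
    x, y = promote(x, y)
    return Point{typeof(x)}(x, y)
end

Point(1, 2.5)     # outer constructor: Point{Float64}(1.0, 2.5)
Point{Int}(3, 4)  # implicit inner constructor, still callable directly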

File tree

1 file changed: +34 −32 lines changed


src/logdensityfunction.jl

Lines changed: 34 additions & 32 deletions
@@ -131,42 +131,44 @@ struct LogDensityFunction{
     adtype::AD
     "(internal use only) gradient preparation object for the model"
     prep::Union{Nothing,DI.GradientPrep}
+end
 
-    function LogDensityFunction(
-        model::Model,
-        getlogdensity::Function=getlogjoint_internal,
-        varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity);
-        adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
-    )
-        if adtype === nothing
-            prep = nothing
+function LogDensityFunction(
+    model::Model,
+    getlogdensity::Function=getlogjoint_internal,
+    varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity);
+    adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
+)
+    if adtype === nothing
+        prep = nothing
+    else
+        # Make backend-specific tweaks to the adtype
+        adtype = tweak_adtype(adtype, model, varinfo)
+        # Check whether it is supported
+        is_supported(adtype) ||
+            @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
+        # Get a set of dummy params to use for prep
+        x = [val for val in varinfo[:]]
+        if use_closure(adtype)
+            prep = DI.prepare_gradient(
+                LogDensityAt(model, getlogdensity, varinfo), adtype, x
+            )
         else
-            # Make backend-specific tweaks to the adtype
-            adtype = tweak_adtype(adtype, model, varinfo)
-            # Check whether it is supported
-            is_supported(adtype) ||
-                @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
-            # Get a set of dummy params to use for prep
-            x = [val for val in varinfo[:]]
-            if use_closure(adtype)
-                prep = DI.prepare_gradient(
-                    LogDensityAt(model, getlogdensity, varinfo), adtype, x
-                )
-            else
-                prep = DI.prepare_gradient(
-                    logdensity_at,
-                    adtype,
-                    x,
-                    DI.Constant(model),
-                    DI.Constant(getlogdensity),
-                    DI.Constant(varinfo),
-                )
-            end
+            prep = DI.prepare_gradient(
+                logdensity_at,
+                adtype,
+                x,
+                DI.Constant(model),
+                DI.Constant(getlogdensity),
+                DI.Constant(varinfo),
+            )
         end
-        return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)}(
-            model, getlogdensity, varinfo, adtype, prep
-        )
     end
+    return LogDensityFunction{
+        typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)
+    }(
+        model, getlogdensity, varinfo, adtype, prep
+    )
 end
 
 """

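The constructor body itself is unchanged by this commit: depending on `use_closure(adtype)`, it prepares the gradient either through a callable that captures the model (`LogDensityAt`) or through `logdensity_at` with `DI.Constant` arguments. A standalone sketch of those two DifferentiationInterface styles, with a toy objective standing in for the log density:

import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff
import ForwardDiff

x = randn(3)

# "Constant" style: extra arguments are threaded through both the
# preparation call and every subsequent gradient call.
f(x, c) = c * sum(abs2, x)
prep = DI.prepare_gradient(f, AutoForwardDiff(), x, DI.Constant(2.0))
DI.gradient(f, prep, AutoForwardDiff(), x, DI.Constant(2.0))

# Closure style: the extra state is captured in a callable struct,
# mirroring what `LogDensityAt(model, getlogdensity, varinfo)` does.
struct ScaledSumSq
    c::Float64
end
(s::ScaledSumSq)(x) = s.c * sum(abs2, x)
prep2 = DI.prepare_gradient(ScaledSumSq(2.0), AutoForwardDiff(), x)
DI.gradient(ScaledSumSq(2.0), prep2, AutoForwardDiff(), x)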