@@ -157,38 +157,44 @@ struct LogDensityFunction{
157157 _adprep:: ADP
158158 _dim:: Int
159159
160+ """
161+ function LogDensityFunction(
162+ model::Model,
163+ getlogdensity::Function=getlogjoint_internal,
164+ link::Union{Bool,Set{VarName}}=false;
165+ adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
166+ )
167+
168+ Generate a `LogDensityFunction` for the given model.
169+
170+ The `link` argument specifies which VarNames in the model should be linked. This can
171+ either be a Bool (if `link=true` all variables are linked; if `link=false` all variables
172+ are unlinked); or a `Set{VarName}` specifying exactly which variables should be linked.
173+ Any sub-variables of the set's elements will be linked.
174+ """
160175 function LogDensityFunction (
161176 model:: Model ,
162177 getlogdensity:: Function = getlogjoint_internal,
163- varinfo :: AbstractVarInfo = VarInfo (model) ;
178+ link :: Union{Bool,Set{VarName}} = false ;
164179 adtype:: Union{ADTypes.AbstractADType,Nothing} = nothing ,
165180 )
166- # Figure out which variable corresponds to which index, and
167- # which variables are linked.
168- all_iden_ranges, all_ranges = get_ranges_and_linked (varinfo)
169- # Figure out if all variables are linked, unlinked, or mixed
170- link_statuses = Bool[]
171- for ral in all_iden_ranges
172- push! (link_statuses, ral. is_linked)
173- end
174- for (_, ral) in all_ranges
175- push! (link_statuses, ral. is_linked)
176- end
177- Tlink = if all (link_statuses)
178- true
179- elseif all (! s for s in link_statuses)
180- false
181- else
182- nothing
183- end
184- x = [val for val in varinfo[:]]
181+ # Run the model once to determine variable ranges and linking. Because the
182+ # parameters stored in the LogDensityFunction are never used, we can just use
183+ # InitFromPrior to create new values. The actual values don't matter, only the
184+ # length, since that's used for gradient prep.
185+ vi = OnlyAccsVarInfo (AccumulatorTuple ((RangeLinkedValueAcc (link),)))
186+ _, vi = DynamicPPL. init!! (model, vi, InitFromPrior ())
187+ rlvacc = first (vi. accs)
188+ Tlink, all_iden_ranges, all_ranges, x = get_data (rlvacc)
189+ @show Tlink, all_iden_ranges, all_ranges, x
190+ # That gives us all the information we need to create the LogDensityFunction.
185191 dim = length (x)
186192 # Do AD prep if needed
187193 prep = if adtype === nothing
188194 nothing
189195 else
190196 # Make backend-specific tweaks to the adtype
191- adtype = DynamicPPL. tweak_adtype (adtype, model, varinfo )
197+ adtype = DynamicPPL. tweak_adtype (adtype, model, x )
192198 DI. prepare_gradient (
193199 LogDensityAt {Tlink} (model, getlogdensity, all_iden_ranges, all_ranges),
194200 adtype,
293299 tweak_adtype(
294300 adtype::ADTypes.AbstractADType,
295301 model::Model,
296- varinfo::AbstractVarInfo,
302+ params::AbstractVector
297303 )
298304
299305Return an 'optimised' form of the adtype. This is useful for doing
@@ -304,79 +310,108 @@ model.
304310
305311By default, this just returns the input unchanged.
306312"""
307- tweak_adtype (adtype:: ADTypes.AbstractADType , :: Model , :: AbstractVarInfo ) = adtype
313+ tweak_adtype (adtype:: ADTypes.AbstractADType , :: Model , :: AbstractVector ) = adtype
308314
309- # #####################################################
310- # Helper functions to extract ranges and link status #
311- # #####################################################
315+ # #############################
316+ # RangeLinkedVal accumulator #
317+ # #############################
312318
313- # This fails for SimpleVarInfo, but honestly there is no reason to support that here. The
314- # fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges
315- # and link status. So there is no motivation to use SimpleVarInfo inside a
316- # LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue
317- # that there is no purpose in supporting untyped VarInfo either.
318- """
319- get_ranges_and_linked(varinfo::VarInfo)
320-
321- Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter
322- representation, along with whether each variable is linked or unlinked.
323-
324- This function should return a tuple containing:
319+ struct RangeLinkedValueAcc{L<: Union{Bool,Set{VarName}} ,N<: NamedTuple } <: AbstractAccumulator
320+ should_link:: L
321+ current_index:: Int
322+ iden_varname_ranges:: N
323+ varname_ranges:: Dict{VarName,RangeAndLinked}
324+ values:: Vector{Any}
325+ end
326+ function RangeLinkedValueAcc (should_link:: Union{Bool,Set{VarName}} )
327+ return RangeLinkedValueAcc (should_link, 1 , (;), Dict {VarName,RangeAndLinked} (), Any[])
328+ end
325329
326- - A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked`
327- - A Dict mapping all other VarNames to their corresponding `RangeAndLinked`.
328- """
329- function get_ranges_and_linked (varinfo:: VarInfo{<:NamedTuple{syms}} ) where {syms}
330- all_iden_ranges = NamedTuple ()
331- all_ranges = Dict {VarName,RangeAndLinked} ()
332- offset = 1
333- for sym in syms
334- md = varinfo. metadata[sym]
335- this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata (md, offset)
336- all_iden_ranges = merge (all_iden_ranges, this_md_iden)
337- all_ranges = merge (all_ranges, this_md_others)
330+ function get_data (rlvacc:: RangeLinkedValueAcc )
331+ link_statuses = Bool[]
332+ for ral in rlvacc. iden_varname_ranges
333+ push! (link_statuses, ral. is_linked)
338334 end
339- return all_iden_ranges, all_ranges
340- end
341- function get_ranges_and_linked (varinfo:: VarInfo{<:Union{Metadata,VarNamedVector}} )
342- all_iden, all_others, _ = get_ranges_and_linked_metadata (varinfo. metadata, 1 )
343- return all_iden, all_others
335+ for (_, ral) in rlvacc. varname_ranges
336+ push! (link_statuses, ral. is_linked)
337+ end
338+ Tlink = if all (link_statuses)
339+ true
340+ elseif all (! s for s in link_statuses)
341+ false
342+ else
343+ nothing
344+ end
345+ return (
346+ Tlink, rlvacc. iden_varname_ranges, rlvacc. varname_ranges, [v for v in rlvacc. values]
347+ )
344348end
345- function get_ranges_and_linked_metadata (md:: Metadata , start_offset:: Int )
346- all_iden_ranges = NamedTuple ()
347- all_ranges = Dict {VarName,RangeAndLinked} ()
348- offset = start_offset
349- for (vn, idx) in md. idcs
350- is_linked = md. is_transformed[idx]
351- range = md. ranges[idx] .+ (start_offset - 1 )
352- if AbstractPPL. getoptic (vn) === identity
353- all_iden_ranges = merge (
354- all_iden_ranges,
355- NamedTuple ((AbstractPPL. getsym (vn) => RangeAndLinked (range, is_linked),)),
356- )
357- else
358- all_ranges[vn] = RangeAndLinked (range, is_linked)
359- end
360- offset += length (range)
349+
350+ accumulator_name (:: Type{<:RangeLinkedValueAcc} ) = :RangeLinkedValueAcc
351+ accumulate_observe!! (acc:: RangeLinkedValueAcc , dist, val, vn) = acc
352+ function accumulate_assume!! (
353+ acc:: RangeLinkedValueAcc , val, logjac, vn:: VarName{sym} , dist:: Distribution
354+ ) where {sym}
355+ link_this_vn = if acc. should_link isa Bool
356+ acc. should_link
357+ else
358+ # Set{VarName}
359+ any (should_link_vn -> subsumes (should_link_vn, vn), acc. should_link)
360+ end
361+ val = if link_this_vn
362+ to_linked_vec_transform (dist)(val)
363+ else
364+ to_vec_transform (dist)(val)
365+ end
366+ new_values = vcat (acc. values, val)
367+ len = length (val)
368+ range = (acc. current_index): (acc. current_index + len - 1 )
369+ ral = RangeAndLinked (range, link_this_vn)
370+ iden_varnames, other_varnames = if getoptic (vn) === identity
371+ merge (acc. iden_varname_ranges, (sym => ral,)), acc. varname_ranges
372+ else
373+ acc. varname_ranges[vn] = ral
374+ acc. iden_varname_ranges, acc. varname_ranges
361375 end
362- return all_iden_ranges, all_ranges, offset
376+ return RangeLinkedValueAcc (
377+ acc. should_link, acc. current_index + len, iden_varnames, other_varnames, new_values
378+ )
363379end
364- function get_ranges_and_linked_metadata (vnv:: VarNamedVector , start_offset:: Int )
365- all_iden_ranges = NamedTuple ()
366- all_ranges = Dict {VarName,RangeAndLinked} ()
367- offset = start_offset
368- for (vn, idx) in vnv. varname_to_index
369- is_linked = vnv. is_unconstrained[idx]
370- range = vnv. ranges[idx] .+ (start_offset - 1 )
371- if AbstractPPL. getoptic (vn) === identity
372- all_iden_ranges = merge (
373- all_iden_ranges,
374- NamedTuple ((AbstractPPL. getsym (vn) => RangeAndLinked (range, is_linked),)),
375- )
376- else
377- all_ranges[vn] = RangeAndLinked (range, is_linked)
378- end
379- offset += length (range)
380+ function Base. copy (acc:: RangeLinkedValueAcc )
381+ return RangeLinkedValueAcc (
382+ acc. should_link,
383+ acc. current_index,
384+ acc. iden_varname_ranges,
385+ copy (acc. varname_ranges),
386+ copy (acc. values),
387+ )
388+ end
389+ _zero (acc:: RangeLinkedValueAcc ) = RangeLinkedValueAcc (acc. should_link)
390+ reset (acc:: RangeLinkedValueAcc ) = _zero (acc)
391+ split (acc:: RangeLinkedValueAcc ) = _zero (acc)
392+ function combine (acc1:: RangeLinkedValueAcc , acc2:: RangeLinkedValueAcc )
393+ new_values = vcat (acc1. values, acc2. values)
394+ new_current_index = acc1. current_index + acc2. current_index - 1
395+ acc2_iden_varnames_shifted = NamedTuple (
396+ k => RangeAndLinked ((ral. range .+ (acc1. current_index - 1 )), ral. is_linked) for
397+ (k, ral) in pairs (acc2. iden_varname_ranges)
398+ )
399+ new_iden_varname_ranges = merge (acc1. iden_varname_ranges, acc2_iden_varnames_shifted)
400+ acc2_varname_ranges_shifted = Dict {VarName,RangeAndLinked} ()
401+ for (k, ral) in acc2. varname_ranges
402+ acc2_varname_ranges_shifted[k] = RangeAndLinked (
403+ (ral. range .+ (acc1. current_index - 1 )), ral. is_linked
404+ )
380405 end
381- return all_iden_ranges, all_ranges, offset
406+ new_varname_ranges = merge (acc1. varname_ranges, acc2_varname_ranges_shifted)
407+ return RangeLinkedValueAcc (
408+ # TODO : using acc1.should_link is not really 'correct', but `should_link` only
409+ # affects model evaluation and `combine` only runs at the end of model evaluation,
410+ # so it shouldn't matter
411+ acc1. should_link,
412+ new_current_index,
413+ new_iden_varname_ranges,
414+ new_varname_ranges,
415+ new_values,
416+ )
382417end
0 commit comments