Skip to content

Commit 11319c0

Browse files
committed
RangeLinkedValAcc
1 parent d7da26d commit 11319c0

File tree

3 files changed

+128
-96
lines changed

3 files changed

+128
-96
lines changed

ext/DynamicPPLForwardDiffExt.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,8 @@ use_dynamicppl_tag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
88
use_dynamicppl_tag(::ADTypes.AutoForwardDiff) = false
99

1010
function DynamicPPL.tweak_adtype(
11-
ad::ADTypes.AutoForwardDiff{chunk_size},
12-
::DynamicPPL.Model,
13-
vi::DynamicPPL.AbstractVarInfo,
11+
ad::ADTypes.AutoForwardDiff{chunk_size}, ::DynamicPPL.Model, params::AbstractVector
1412
) where {chunk_size}
15-
params = vi[:]
16-
1713
# Use DynamicPPL tag to improve stack traces
1814
# https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
1915
# NOTE: DifferentiationInterface disables tag checking if the

src/logdensityfunction.jl

Lines changed: 124 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -157,38 +157,44 @@ struct LogDensityFunction{
157157
_adprep::ADP
158158
_dim::Int
159159

160+
"""
161+
function LogDensityFunction(
162+
model::Model,
163+
getlogdensity::Function=getlogjoint_internal,
164+
link::Union{Bool,Set{VarName}}=false;
165+
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
166+
)
167+
168+
Generate a `LogDensityFunction` for the given model.
169+
170+
The `link` argument specifies which VarNames in the model should be linked. This can
171+
either be a Bool (if `link=true` all variables are linked; if `link=false` all variables
172+
are unlinked); or a `Set{VarName}` specifying exactly which variables should be linked.
173+
Any sub-variables of the set's elements will be linked.
174+
"""
160175
function LogDensityFunction(
161176
model::Model,
162177
getlogdensity::Function=getlogjoint_internal,
163-
varinfo::AbstractVarInfo=VarInfo(model);
178+
link::Union{Bool,Set{VarName}}=false;
164179
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
165180
)
166-
# Figure out which variable corresponds to which index, and
167-
# which variables are linked.
168-
all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo)
169-
# Figure out if all variables are linked, unlinked, or mixed
170-
link_statuses = Bool[]
171-
for ral in all_iden_ranges
172-
push!(link_statuses, ral.is_linked)
173-
end
174-
for (_, ral) in all_ranges
175-
push!(link_statuses, ral.is_linked)
176-
end
177-
Tlink = if all(link_statuses)
178-
true
179-
elseif all(!s for s in link_statuses)
180-
false
181-
else
182-
nothing
183-
end
184-
x = [val for val in varinfo[:]]
181+
# Run the model once to determine variable ranges and linking. Because the
182+
# parameters stored in the LogDensityFunction are never used, we can just use
183+
# InitFromPrior to create new values. The actual values don't matter, only the
184+
# length, since that's used for gradient prep.
185+
vi = OnlyAccsVarInfo(AccumulatorTuple((RangeLinkedValueAcc(link),)))
186+
_, vi = DynamicPPL.init!!(model, vi, InitFromPrior())
187+
rlvacc = first(vi.accs)
188+
Tlink, all_iden_ranges, all_ranges, x = get_data(rlvacc)
189+
@show Tlink, all_iden_ranges, all_ranges, x
190+
# That gives us all the information we need to create the LogDensityFunction.
185191
dim = length(x)
186192
# Do AD prep if needed
187193
prep = if adtype === nothing
188194
nothing
189195
else
190196
# Make backend-specific tweaks to the adtype
191-
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
197+
adtype = DynamicPPL.tweak_adtype(adtype, model, x)
192198
DI.prepare_gradient(
193199
LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges),
194200
adtype,
@@ -293,7 +299,7 @@ end
293299
tweak_adtype(
294300
adtype::ADTypes.AbstractADType,
295301
model::Model,
296-
varinfo::AbstractVarInfo,
302+
params::AbstractVector
297303
)
298304
299305
Return an 'optimised' form of the adtype. This is useful for doing
@@ -304,79 +310,108 @@ model.
304310
305311
By default, this just returns the input unchanged.
306312
"""
307-
tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype
313+
tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVector) = adtype
308314

309-
######################################################
310-
# Helper functions to extract ranges and link status #
311-
######################################################
315+
##############################
316+
# RangeLinkedVal accumulator #
317+
##############################
312318

313-
# This fails for SimpleVarInfo, but honestly there is no reason to support that here. The
314-
# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges
315-
# and link status. So there is no motivation to use SimpleVarInfo inside a
316-
# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue
317-
# that there is no purpose in supporting untyped VarInfo either.
318-
"""
319-
get_ranges_and_linked(varinfo::VarInfo)
320-
321-
Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter
322-
representation, along with whether each variable is linked or unlinked.
323-
324-
This function should return a tuple containing:
319+
struct RangeLinkedValueAcc{L<:Union{Bool,Set{VarName}},N<:NamedTuple} <: AbstractAccumulator
320+
should_link::L
321+
current_index::Int
322+
iden_varname_ranges::N
323+
varname_ranges::Dict{VarName,RangeAndLinked}
324+
values::Vector{Any}
325+
end
326+
function RangeLinkedValueAcc(should_link::Union{Bool,Set{VarName}})
327+
return RangeLinkedValueAcc(should_link, 1, (;), Dict{VarName,RangeAndLinked}(), Any[])
328+
end
325329

326-
- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked`
327-
- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`.
328-
"""
329-
function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms}
330-
all_iden_ranges = NamedTuple()
331-
all_ranges = Dict{VarName,RangeAndLinked}()
332-
offset = 1
333-
for sym in syms
334-
md = varinfo.metadata[sym]
335-
this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset)
336-
all_iden_ranges = merge(all_iden_ranges, this_md_iden)
337-
all_ranges = merge(all_ranges, this_md_others)
330+
function get_data(rlvacc::RangeLinkedValueAcc)
331+
link_statuses = Bool[]
332+
for ral in rlvacc.iden_varname_ranges
333+
push!(link_statuses, ral.is_linked)
338334
end
339-
return all_iden_ranges, all_ranges
340-
end
341-
function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}})
342-
all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1)
343-
return all_iden, all_others
335+
for (_, ral) in rlvacc.varname_ranges
336+
push!(link_statuses, ral.is_linked)
337+
end
338+
Tlink = if all(link_statuses)
339+
true
340+
elseif all(!s for s in link_statuses)
341+
false
342+
else
343+
nothing
344+
end
345+
return (
346+
Tlink, rlvacc.iden_varname_ranges, rlvacc.varname_ranges, [v for v in rlvacc.values]
347+
)
344348
end
345-
function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int)
346-
all_iden_ranges = NamedTuple()
347-
all_ranges = Dict{VarName,RangeAndLinked}()
348-
offset = start_offset
349-
for (vn, idx) in md.idcs
350-
is_linked = md.is_transformed[idx]
351-
range = md.ranges[idx] .+ (start_offset - 1)
352-
if AbstractPPL.getoptic(vn) === identity
353-
all_iden_ranges = merge(
354-
all_iden_ranges,
355-
NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)),
356-
)
357-
else
358-
all_ranges[vn] = RangeAndLinked(range, is_linked)
359-
end
360-
offset += length(range)
349+
350+
accumulator_name(::Type{<:RangeLinkedValueAcc}) = :RangeLinkedValueAcc
351+
accumulate_observe!!(acc::RangeLinkedValueAcc, dist, val, vn) = acc
352+
function accumulate_assume!!(
353+
acc::RangeLinkedValueAcc, val, logjac, vn::VarName{sym}, dist::Distribution
354+
) where {sym}
355+
link_this_vn = if acc.should_link isa Bool
356+
acc.should_link
357+
else
358+
# Set{VarName}
359+
any(should_link_vn -> subsumes(should_link_vn, vn), acc.should_link)
360+
end
361+
val = if link_this_vn
362+
to_linked_vec_transform(dist)(val)
363+
else
364+
to_vec_transform(dist)(val)
365+
end
366+
new_values = vcat(acc.values, val)
367+
len = length(val)
368+
range = (acc.current_index):(acc.current_index + len - 1)
369+
ral = RangeAndLinked(range, link_this_vn)
370+
iden_varnames, other_varnames = if getoptic(vn) === identity
371+
merge(acc.iden_varname_ranges, (sym => ral,)), acc.varname_ranges
372+
else
373+
acc.varname_ranges[vn] = ral
374+
acc.iden_varname_ranges, acc.varname_ranges
361375
end
362-
return all_iden_ranges, all_ranges, offset
376+
return RangeLinkedValueAcc(
377+
acc.should_link, acc.current_index + len, iden_varnames, other_varnames, new_values
378+
)
363379
end
364-
function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int)
365-
all_iden_ranges = NamedTuple()
366-
all_ranges = Dict{VarName,RangeAndLinked}()
367-
offset = start_offset
368-
for (vn, idx) in vnv.varname_to_index
369-
is_linked = vnv.is_unconstrained[idx]
370-
range = vnv.ranges[idx] .+ (start_offset - 1)
371-
if AbstractPPL.getoptic(vn) === identity
372-
all_iden_ranges = merge(
373-
all_iden_ranges,
374-
NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)),
375-
)
376-
else
377-
all_ranges[vn] = RangeAndLinked(range, is_linked)
378-
end
379-
offset += length(range)
380+
function Base.copy(acc::RangeLinkedValueAcc)
381+
return RangeLinkedValueAcc(
382+
acc.should_link,
383+
acc.current_index,
384+
acc.iden_varname_ranges,
385+
copy(acc.varname_ranges),
386+
copy(acc.values),
387+
)
388+
end
389+
_zero(acc::RangeLinkedValueAcc) = RangeLinkedValueAcc(acc.should_link)
390+
reset(acc::RangeLinkedValueAcc) = _zero(acc)
391+
split(acc::RangeLinkedValueAcc) = _zero(acc)
392+
function combine(acc1::RangeLinkedValueAcc, acc2::RangeLinkedValueAcc)
393+
new_values = vcat(acc1.values, acc2.values)
394+
new_current_index = acc1.current_index + acc2.current_index - 1
395+
acc2_iden_varnames_shifted = NamedTuple(
396+
k => RangeAndLinked((ral.range .+ (acc1.current_index - 1)), ral.is_linked) for
397+
(k, ral) in pairs(acc2.iden_varname_ranges)
398+
)
399+
new_iden_varname_ranges = merge(acc1.iden_varname_ranges, acc2_iden_varnames_shifted)
400+
acc2_varname_ranges_shifted = Dict{VarName,RangeAndLinked}()
401+
for (k, ral) in acc2.varname_ranges
402+
acc2_varname_ranges_shifted[k] = RangeAndLinked(
403+
(ral.range .+ (acc1.current_index - 1)), ral.is_linked
404+
)
380405
end
381-
return all_iden_ranges, all_ranges, offset
406+
new_varname_ranges = merge(acc1.varname_ranges, acc2_varname_ranges_shifted)
407+
return RangeLinkedValueAcc(
408+
# TODO: using acc1.should_link is not really 'correct', but `should_link` only
409+
# affects model evaluation and `combine` only runs at the end of model evaluation,
410+
# so it shouldn't matter
411+
acc1.should_link,
412+
new_current_index,
413+
new_iden_varname_ranges,
414+
new_varname_ranges,
415+
new_values,
416+
)
382417
end

test/ext/DynamicPPLForwardDiffExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,17 @@ using Test: @test, @testset
1414
@model f() = x ~ MvNormal(zeros(MODEL_SIZE), I)
1515
model = f()
1616
varinfo = VarInfo(model)
17+
x = varinfo[:]
1718

1819
@testset "Chunk size setting" for chunksize in (nothing, 0)
1920
base_adtype = AutoForwardDiff(; chunksize=chunksize)
20-
new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo)
21+
new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, x)
2122
@test new_adtype isa AutoForwardDiff{MODEL_SIZE}
2223
end
2324

2425
@testset "Tag setting" begin
2526
base_adtype = AutoForwardDiff()
26-
new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo)
27+
new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, x)
2728
@test new_adtype.tag isa ForwardDiff.Tag{DynamicPPL.DynamicPPLTag}
2829
end
2930
end

0 commit comments

Comments
 (0)