From 5a5a4e9ffdd6f1bc6ded4bb48a5e1256fd8999b6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 17:57:23 +0100 Subject: [PATCH 1/8] Remove `SamplingContext` for good --- docs/src/api.md | 12 +--- ext/DynamicPPLEnzymeCoreExt.jl | 2 - src/DynamicPPL.jl | 3 - src/context_implementations.jl | 106 ++------------------------------- src/contexts.jl | 69 +-------------------- src/debug_utils.jl | 2 +- src/sampler.jl | 45 -------------- src/simple_varinfo.jl | 19 ------ src/utils.jl | 44 -------------- test/Project.toml | 2 - test/ad.jl | 43 ------------- test/contexts.jl | 22 +------ test/debug_utils.jl | 2 +- test/ext/DynamicPPLJETExt.jl | 6 ++ test/lkj.jl | 33 +++------- test/threadsafe.jl | 10 ++-- 16 files changed, 31 insertions(+), 389 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index d1dddb560..8c3444eb3 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -447,12 +447,12 @@ AbstractPPL.evaluate!! This method mutates the `varinfo` used for execution. By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. +If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this. The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. ```@docs -SamplingContext DefaultContext PrefixContext ConditionContext @@ -486,15 +486,7 @@ DynamicPPL.init ### Samplers -In DynamicPPL two samplers are defined that are used to initialize unobserved random variables: -[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution. - -```@docs -SampleFromPrior -SampleFromUniform -``` - -Additionally, a generic sampler for inference is implemented. +In DynamicPPL a generic sampler for inference is implemented. ```@docs Sampler diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index d592e76b3..0088f8908 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -8,8 +8,6 @@ else using ..EnzymeCore end -@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true - # Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.istrans), args...) = nothing diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 6a01884a9..67a90cf48 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -96,13 +96,10 @@ export AbstractVarInfo, values_as_in_model, # Samplers Sampler, - SampleFromPrior, - SampleFromUniform, # LogDensityFunction LogDensityFunction, # Contexts contextualize, - SamplingContext, DefaultContext, PrefixContext, ConditionContext, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 786d7c913..e38ffe6e6 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -1,38 +1,14 @@ # assume -""" - tilde_assume(context::SamplingContext, right, vn, vi) - -Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value with a context associated -with a sampler. 
- -Falls back to -```julia -tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) -``` -""" -function tilde_assume(context::SamplingContext, right, vn, vi) - return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) -end - function tilde_assume(context::AbstractContext, args...) return tilde_assume(childcontext(context), args...) end function tilde_assume(::DefaultContext, right, vn, vi) - return assume(right, vn, vi) -end - -function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) - return tilde_assume(rng, childcontext(context), args...) -end -function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) - return assume(rng, sampler, right, vn, vi) -end -function tilde_assume(::DefaultContext, sampler, right, vn, vi) - # same as above but no rng - return assume(Random.default_rng(), sampler, right, vn, vi) + y = getindex_internal(vi, vn) + f = from_maybe_linked_internal_transform(vi, vn, right) + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) + return x, vi end - function tilde_assume(context::PrefixContext, right, vn, vi) # Note that we can't use something like this here: # new_vn = prefix(context, vn) @@ -46,12 +22,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi) new_vn, new_context = prefix_and_strip_contexts(context, vn) return tilde_assume(new_context, right, new_vn, vi) end -function tilde_assume( - rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi -) - new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume(rng, new_context, sampler, right, new_vn, vi) -end """ tilde_assume!!(context, right, vn, vi) @@ -71,17 +41,6 @@ function tilde_assume!!(context, right, vn, vi) end # observe -""" - tilde_observe!!(context::SamplingContext, right, left, vi) - -Handle observed constants with a `context` associated with a sampler. - -Falls back to `tilde_observe!!(context.context, right, left, vi)`. -""" -function tilde_observe!!(context::SamplingContext, right, left, vn, vi) - return tilde_observe!!(context.context, right, left, vn, vi) -end - function tilde_observe!!(context::AbstractContext, right, left, vn, vi) return tilde_observe!!(childcontext(context), right, left, vn, vi) end @@ -114,58 +73,3 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi) vi = accumulate_observe!!(vi, right, left, vn) return left, vi end - -function assume(::Random.AbstractRNG, spl::Sampler, dist) - return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") -end - -# fallback without sampler -function assume(dist::Distribution, vn::VarName, vi) - y = getindex_internal(vi, vn) - f = from_maybe_linked_internal_transform(vi, vn, dist) - x, inv_logjac = with_logabsdet_jacobian(f, y) - vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) - return x, vi -end - -# TODO: Remove this thing. -# SampleFromPrior and SampleFromUniform -function assume( - rng::Random.AbstractRNG, - sampler::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, - vn::VarName, - vi::VarInfoOrThreadSafeVarInfo, -) - if haskey(vi, vn) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") - # TODO(mhauru) Is it important to unset the flag here? The `true` allows us - # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure - # if that's okay. 
- unset_flag!(vi, vn, "del", true) - r = init(rng, dist, sampler) - f = to_maybe_linked_internal_transform(vi, vn, dist) - # TODO(mhauru) This should probably be call a function called setindex_internal! - vi = BangBang.setindex!!(vi, f(r), vn) - else - # Otherwise we just extract it. - r = vi[vn, dist] - end - else - r = init(rng, dist, sampler) - if istrans(vi) - f = to_linked_internal_transform(vi, vn, dist) - vi = push!!(vi, vn, f(r), dist) - # By default `push!!` sets the transformed flag to `false`. - vi = settrans!!(vi, true, vn) - else - vi = push!!(vi, vn, r, dist) - end - end - - # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. - logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) - vi = accumulate_assume!!(vi, r, logjac, vn, dist) - return r, vi -end diff --git a/src/contexts.jl b/src/contexts.jl index cd9876768..8b5e866d0 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -47,7 +47,7 @@ effectively updating the child context. ```jldoctest julia> using DynamicPPL: DynamicTransformationContext -julia> ctx = SamplingContext(); +julia> ctx = ConditionContext((; a = 1)); julia> DynamicPPL.childcontext(ctx) DefaultContext() @@ -121,73 +121,6 @@ setleafcontext(::IsLeaf, ::IsParent, left, right) = right setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right # Contexts -""" - SamplingContext( - [rng::Random.AbstractRNG=Random.default_rng()], - [sampler::AbstractSampler=SampleFromPrior()], - [context::AbstractContext=DefaultContext()], - ) - -Create a context that allows you to sample parameters with the `sampler` when running the model. -The `context` determines how the returned log density is computed when running the model. - -See also: [`DefaultContext`](@ref) -""" -struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext - rng::R - sampler::S - context::C -end - -function SamplingContext( - rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior() -) - return SamplingContext(rng, sampler, DefaultContext()) -end - -function SamplingContext( - sampler::AbstractSampler, context::AbstractContext=DefaultContext() -) - return SamplingContext(Random.default_rng(), sampler, context) -end - -function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext) - return SamplingContext(rng, SampleFromPrior(), context) -end - -function SamplingContext(context::AbstractContext) - return SamplingContext(Random.default_rng(), SampleFromPrior(), context) -end - -NodeTrait(context::SamplingContext) = IsParent() -childcontext(context::SamplingContext) = context.context -function setchildcontext(parent::SamplingContext, child) - return SamplingContext(parent.rng, parent.sampler, child) -end - -""" - hassampler(context) - -Return `true` if `context` has a sampler. -""" -hassampler(::SamplingContext) = true -hassampler(context::AbstractContext) = hassampler(NodeTrait(context), context) -hassampler(::IsLeaf, context::AbstractContext) = false -hassampler(::IsParent, context::AbstractContext) = hassampler(childcontext(context)) - -""" - getsampler(context) - -Return the sampler of the context `context`. - -This will traverse the context tree until it reaches the first [`SamplingContext`](@ref), -at which point it will return the sampler of that context. 
-""" -getsampler(context::SamplingContext) = context.sampler -getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context) -getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context)) -getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context") - """ struct DefaultContext <: AbstractContext end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index c2be4b46b..2ec8b15a2 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -485,7 +485,7 @@ and checking if the model is consistent across runs. function has_static_constraints( rng::Random.AbstractRNG, model::Model; num_evals::Int=5, error_on_failure::Bool=false ) - new_model = DynamicPPL.contextualize(model, SamplingContext(rng, SampleFromPrior())) + new_model = DynamicPPL.contextualize(model, InitContext(rng)) results = map(1:num_evals) do _ check_model_and_trace(new_model, VarInfo(); error_on_failure=error_on_failure) end diff --git a/src/sampler.jl b/src/sampler.jl index 98b50ba55..8b49f6c3b 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -1,34 +1,3 @@ -# TODO: Make `UniformSampling` and `Prior` algs + just use `Sampler` -# That would let us use all defaults for Sampler, combine it with other samplers etc. -""" - SampleFromUniform - -Sampling algorithm that samples unobserved random variables from a uniform distribution. - -# References - -[Stan reference manual](https://mc-stan.org/docs/2_28/reference-manual/initialization.html#random-initial-values) -""" -struct SampleFromUniform <: AbstractSampler end - -""" - SampleFromPrior - -Sampling algorithm that samples unobserved random variables from their prior distribution. -""" -struct SampleFromPrior <: AbstractSampler end - -# Initializations. -init(rng, dist, ::SampleFromPrior) = rand(rng, dist) -function init(rng, dist, ::SampleFromUniform) - return istransformable(dist) ? inittrans(rng, dist) : rand(rng, dist) -end - -init(rng, dist, ::SampleFromPrior, n::Int) = rand(rng, dist, n) -function init(rng, dist, ::SampleFromUniform, n::Int) - return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n) -end - # TODO(mhauru) Could we get rid of Sampler now that it's just a wrapper around `alg`? # (Selector has been removed). """ @@ -49,20 +18,6 @@ struct Sampler{T} <: AbstractSampler alg::T end -# AbstractMCMC interface for SampleFromUniform and SampleFromPrior -function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::Model, - sampler::Union{SampleFromUniform,SampleFromPrior}, - state=nothing; - kwargs..., -) - vi = VarInfo() - strategy = sampler isa SampleFromPrior ? InitFromPrior() : InitFromUniform() - _, new_vi = DynamicPPL.init!!(rng, model, vi, strategy) - return new_vi, nothing -end - """ default_varinfo(rng, model, sampler) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 27365e4dc..f430755e7 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -466,25 +466,6 @@ function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) return SimpleVarInfo(values, accs, transformation) end -# Context implementations -# NOTE: Evaluations, i.e. those without `rng` are shared with other -# implementations of `AbstractVarInfo`. -function assume( - rng::Random.AbstractRNG, - sampler::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, - vn::VarName, - vi::SimpleOrThreadSafeSimple, -) - value = init(rng, dist, sampler) - # Transform if we're working in unconstrained space. 
- f = to_maybe_linked_internal_transform(vi, vn, dist) - value_raw, logjac = with_logabsdet_jacobian(f, value) - vi = BangBang.push!!(vi, vn, value_raw, dist) - vi = accumulate_assume!!(vi, value, logjac, vn, dist) - return value, vi -end - function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) end diff --git a/src/utils.jl b/src/utils.jl index c7d1e089f..a4c5f4a1b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -456,50 +456,6 @@ function recombine(d::MultivariateDistribution, val::AbstractVector, n::Int) return copy(reshape(val, length(d), n)) end -# Uniform random numbers with range 4 for robust initializations -# Reference: https://mc-stan.org/docs/2_19/reference-manual/initialization.html -randrealuni(rng::Random.AbstractRNG) = 4 * rand(rng) - 2 -randrealuni(rng::Random.AbstractRNG, args...) = 4 .* rand(rng, args...) .- 2 - -istransformable(dist) = link_transform(dist) !== identity - -################################# -# Single-sample initialisations # -################################# - -inittrans(rng, dist::UnivariateDistribution) = Bijectors.invlink(dist, randrealuni(rng)) -function inittrans(rng, dist::MultivariateDistribution) - # Get the length of the unconstrained vector - b = link_transform(dist) - d = Bijectors.output_length(b, length(dist)) - return Bijectors.invlink(dist, randrealuni(rng, d)) -end -function inittrans(rng, dist::MatrixDistribution) - # Get the size of the unconstrained vector - b = link_transform(dist) - sz = Bijectors.output_size(b, size(dist)) - return Bijectors.invlink(dist, randrealuni(rng, sz...)) -end -function inittrans(rng, dist::Distribution{CholeskyVariate}) - # Get the size of the unconstrained vector - b = link_transform(dist) - sz = Bijectors.output_size(b, size(dist)) - return Bijectors.invlink(dist, randrealuni(rng, sz...)) -end -################################ -# Multi-sample initialisations # -################################ - -function inittrans(rng, dist::UnivariateDistribution, n::Int) - return Bijectors.invlink(dist, randrealuni(rng, n)) -end -function inittrans(rng, dist::MultivariateDistribution, n::Int) - return Bijectors.invlink(dist, randrealuni(rng, size(dist)[1], n)) -end -function inittrans(rng, dist::MatrixDistribution, n::Int) - return Bijectors.invlink(dist, [randrealuni(rng, size(dist)...) 
for _ in 1:n]) -end - ####################### # Convenience methods # ####################### diff --git a/test/Project.toml b/test/Project.toml index 91a885e96..5d860381d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,7 +11,6 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -39,7 +38,6 @@ DifferentiationInterface = "0.6.41, 0.7" Distributions = "0.25" DistributionsAD = "0.6.3" Documenter = "1" -EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10.12, 1" JET = "0.9, 0.10" LogDensityProblems = "2" diff --git a/test/ad.jl b/test/ad.jl index 0e5d8d7cf..23e676ee7 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -77,49 +77,6 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest end end - @testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin - # Failing model - t = 1:0.05:8 - σ = 0.3 - y = @. rand(sin(t) + Normal(0, σ)) - @model function state_space(y, TT, ::Type{T}=Float64) where {T} - # Priors - α ~ Normal(y[1], 0.001) - τ ~ Exponential(1) - η ~ filldist(Normal(0, 1), TT - 1) - σ ~ Exponential(1) - # create latent variable - x = Vector{T}(undef, TT) - x[1] = α - for t in 2:TT - x[t] = x[t - 1] + η[t - 1] * τ - end - # measurement model - y ~ MvNormal(x, σ^2 * I) - return x - end - model = state_space(y, length(t)) - - # Dummy sampling algorithm for testing. The test case can only be replicated - # with a custom sampler, it doesn't work with SampleFromPrior(). We need to - # overload assume so that model evaluation doesn't fail due to a lack - # of implementation - struct MyEmptyAlg end - DynamicPPL.assume( - ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi - ) = DynamicPPL.assume(dist, vn, vi) - - # Compiling the ReverseDiff tape used to fail here - spl = Sampler(MyEmptyAlg()) - vi = DynamicPPL.link!!(VarInfo(model), model) - sampling_model = contextualize(model, SamplingContext(model.context)) - ldf = LogDensityFunction( - sampling_model, getlogjoint_internal, vi; adtype=AutoReverseDiff(; compile=true) - ) - x = ldf.varinfo[:] - @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Any - end - # Test that various different ways of specifying array types as arguments work with all # ADTypes. @testset "Array argument types" begin diff --git a/test/contexts.jl b/test/contexts.jl index 1a6279bf4..2687c4336 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -24,8 +24,6 @@ using DynamicPPL: using LinearAlgebra: I using Random: Xoshiro -using EnzymeCore - # TODO: Should we maybe put this in DPPL itself? 
function Base.iterate(context::AbstractContext) if NodeTrait(context) isa IsLeaf @@ -150,11 +148,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() vn = @varname(x[1]) ctx1 = PrefixContext(@varname(a)) @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) - ctx2 = SamplingContext(ctx1) + ctx2 = ConditionContext(Dict{VarName,Any}(), ctx1) @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) ctx3 = PrefixContext(@varname(b), ctx2) @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) - ctx4 = DynamicPPL.SamplingContext(ctx3) + ctx4 = FixedContext(Dict(), ctx3) @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end @@ -203,22 +201,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "SamplingContext" begin - context = SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()) - @test context isa SamplingContext - - # convenience constructors - @test SamplingContext() == context - @test SamplingContext(Random.default_rng()) == context - @test SamplingContext(SampleFromPrior()) == context - @test SamplingContext(DefaultContext()) == context - @test SamplingContext(Random.default_rng(), SampleFromPrior()) == context - @test SamplingContext(Random.default_rng(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test EnzymeCore.EnzymeRules.inactive_type(typeof(context)) - end - @testset "ConditionContext" begin @testset "Nesting" begin @testset "NamedTuple" begin diff --git a/test/debug_utils.jl b/test/debug_utils.jl index 5bf741ff3..f950f6b45 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -149,7 +149,7 @@ model = demo_missing_in_multivariate([1.0, missing]) # Have to run this check_model call with an empty varinfo, because actually # instantiating the VarInfo would cause it to throw a MethodError. - model = contextualize(model, SamplingContext()) + model = contextualize(model, InitContext()) @test_throws ErrorException check_model(model, VarInfo(); error_on_failure=true) end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 692f53911..b34424a1c 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -64,6 +64,12 @@ @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) + # Check that the inferred varinfo is indeed suitable for evaluation + f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, varinfo + ) + JET.test_call(f_eval, argtypes_eval) + # For our demo models, they should all result in typed. is_typed = varinfo isa DynamicPPL.NTVarInfo @test is_typed diff --git a/test/lkj.jl b/test/lkj.jl index d581cd21b..03e744b84 100644 --- a/test/lkj.jl +++ b/test/lkj.jl @@ -16,20 +16,15 @@ end # Same for both distributions target_mean = vec(Matrix{Float64}(I, 2, 2)) +n_samples = 1000 _lkj_atol = 0.05 @testset "Sample from x ~ LKJ(2, 1)" begin model = lkj_prior_demo() - # `SampleFromPrior` will sample in constrained space. - @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000; progress=false) - @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = - _lkj_atol - end - - # `SampleFromUniform` will sample in unconstrained space. 
- @testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000; progress=false) + for init_strategy in [PriorInit(), UniformInit()] + samples = [ + last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples + ] @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = _lkj_atol end @@ -38,20 +33,10 @@ end @testset "Sample from x ~ LKJCholesky(2, 1, $(uplo))" for uplo in ['U', 'L'] model = lkj_chol_prior_demo(uplo) # `SampleFromPrior` will sample in unconstrained space. - @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000; progress=false) - # Build correlation matrix from factor - corr_matrices = map(samples) do s - M = reshape(s.metadata.vals, (2, 2)) - pd_from_triangular(M, uplo) - end - @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol - end - - # `SampleFromUniform` will sample in unconstrained space. - @testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000; progress=false) - # Build correlation matrix from factor + for init_strategy in [PriorInit(), UniformInit()] + samples = [ + last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples + ] corr_matrices = map(samples) do s M = reshape(s.metadata.vals, (2, 2)) pd_from_triangular(M, uplo) diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 0421c89e2..522730566 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -68,8 +68,7 @@ @time model(vi) # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) + DynamicPPL.evaluate_threadsafe!!(model, vi) @test getlogjoint(vi) ≈ lp_w_threads # check that it's wrapped during the model evaluation @test vi_ isa DynamicPPL.ThreadSafeVarInfo @@ -77,7 +76,7 @@ @test vi isa VarInfo println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) + @time DynamicPPL.evaluate_threadsafe!!(model, vi) @model function wothreads(x) global vi_ = __varinfo__ @@ -104,13 +103,12 @@ @test lp_w_threads ≈ lp_wo_threads # Ensure that we use `VarInfo`. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) + DynamicPPL.evaluate_threadunsafe!!(model, vi) @test getlogjoint(vi) ≈ lp_w_threads @test vi_ isa VarInfo @test vi isa VarInfo println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) + @time DynamicPPL.evaluate_threadunsafe!!(model, vi) end end From dde8b7e29ef1a51b429cba2c6185eceeaec13b85 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 18:05:29 +0100 Subject: [PATCH 2/8] Remove `tilde_assume` as well --- docs/src/api.md | 15 ++++++++------- src/DynamicPPL.jl | 5 +++-- src/context_implementations.jl | 23 ++++++++--------------- src/contexts.jl | 2 +- src/contexts/init.jl | 2 +- src/transforming.jl | 2 +- 6 files changed, 22 insertions(+), 27 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 8c3444eb3..e5c483bca 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -8,7 +8,7 @@ Part of the API of DynamicPPL is defined in the more lightweight interface packa A core component of DynamicPPL is the [`@model`](@ref) macro. It can be used to define probabilistic models in an intuitive way by specifying random variables and their distributions with `~` statements. 
-These statements are rewritten by `@model` as calls of [internal functions](@ref model_internal) for sampling the variables and computing their log densities. +These statements are rewritten by `@model` as calls of internal functions for sampling the variables and computing their log densities. ```@docs @model @@ -344,6 +344,13 @@ Base.empty! SimpleVarInfo ``` +### Tilde-pipeline + +```@docs +tilde_assume!! +tilde_observe!! +``` + ### Accumulators The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. @@ -512,9 +519,3 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va DynamicPPL.Experimental.determine_suitable_varinfo DynamicPPL.Experimental.is_suitable_varinfo ``` - -### [Model-Internal Functions](@id model_internal) - -```@docs -tilde_assume -``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 67a90cf48..edf44439e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -103,8 +103,9 @@ export AbstractVarInfo, DefaultContext, PrefixContext, ConditionContext, - assume, - tilde_assume, + # Tilde pipeline + tilde_assume!!, + tilde_observe!!, # Initialisation InitContext, AbstractInitStrategy, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e38ffe6e6..f25b63a64 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -1,18 +1,18 @@ # assume -function tilde_assume(context::AbstractContext, args...) - return tilde_assume(childcontext(context), args...) +function tilde_assume!!(context::AbstractContext, right::Distribution, vn, vi) + return tilde_assume!!(childcontext(context), right, vn, vi) end -function tilde_assume(::DefaultContext, right, vn, vi) +function tilde_assume!!(::DefaultContext, right::Distribution, vn, vi) y = getindex_internal(vi, vn) f = from_maybe_linked_internal_transform(vi, vn, right) x, inv_logjac = with_logabsdet_jacobian(f, y) vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) return x, vi end -function tilde_assume(context::PrefixContext, right, vn, vi) +function tilde_assume!!(context::PrefixContext, right::Distribution, vn, vi) # Note that we can't use something like this here: # new_vn = prefix(context, vn) - # return tilde_assume(childcontext(context), right, new_vn, vi) + # return tilde_assume!!(childcontext(context), right, new_vn, vi) # This is because `prefix` applies _all_ prefixes in a given context to a # variable name. Thus, if we had two levels of nested prefixes e.g. # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the @@ -20,7 +20,7 @@ function tilde_assume(context::PrefixContext, right, vn, vi) # would apply the prefix `b._`, resulting in `b.a.b._`. # This is why we need a special function, `prefix_and_strip_contexts`. new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume(new_context, right, new_vn, vi) + return tilde_assume!!(new_context, right, new_vn, vi) end """ @@ -28,16 +28,9 @@ end Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value and updated `vi`. - -By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log -probability of `vi` with the returned value. 
""" -function tilde_assume!!(context, right, vn, vi) - return if right isa DynamicPPL.Submodel - _evaluate!!(right, vi, context, vn) - else - tilde_assume(context, right, vn, vi) - end +function tilde_assume!!(context, right::DynamicPPL.Submodel, vn, vi) + return _evaluate!!(right, vi, context, vn) end # observe diff --git a/src/contexts.jl b/src/contexts.jl index 8b5e866d0..439da47e5 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -185,7 +185,7 @@ PrefixContexts removed. NOTE: This does _not_ modify any variables in any `ConditionContext` and `FixedContext` that may be present in the context stack. This is because this -function is only used in `tilde_assume`, which is lower in the tilde-pipeline +function is only used in `tilde_assume!!`, which is lower in the tilde-pipeline than `contextual_isassumption` and `contextual_isfixed` (the functions which actually use the `ConditionContext` and `FixedContext` values). Thus, by this time, any `ConditionContext`s and `FixedContext`s present have already served diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 636847117..7eea73f66 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -154,7 +154,7 @@ struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractCon end NodeTrait(::InitContext) = IsLeaf() -function tilde_assume( +function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) in_varinfo = haskey(vi, vn) diff --git a/src/transforming.jl b/src/transforming.jl index 56f861cff..3569d1502 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -12,7 +12,7 @@ how to do the transformation, used by e.g. `SimpleVarInfo`. struct DynamicTransformationContext{isinverse} <: AbstractContext end NodeTrait(::DynamicTransformationContext) = IsLeaf() -function tilde_assume( +function tilde_assume!!( ::DynamicTransformationContext{isinverse}, right, vn, vi ) where {isinverse} # vi[vn, right] always provides the value in unlinked space. From d7c40338552f076a0e3676770254912b6191bd6d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 18 Jul 2025 17:42:13 +0100 Subject: [PATCH 3/8] Split up tilde_observe!! for Distribution / Submodel --- src/context_implementations.jl | 8 +++++--- src/transforming.jl | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f25b63a64..92200582c 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -60,9 +60,11 @@ accumulate the log probability, and return the observed value and updated `vi`. Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. 
""" -function tilde_observe!!(::DefaultContext, right, left, vn, vi) - right isa DynamicPPL.Submodel && - throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) +function tilde_observe!!(::DefaultContext, right::Distribution, left, vn, vi) vi = accumulate_observe!!(vi, right, left, vn) return left, vi end + +function tilde_observe!!(::DefaultContext, ::DynamicPPL.Submodel, left, vn, vi) + throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) +end diff --git a/src/transforming.jl b/src/transforming.jl index 3569d1502..5465b2ff2 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -13,7 +13,7 @@ struct DynamicTransformationContext{isinverse} <: AbstractContext end NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume!!( - ::DynamicTransformationContext{isinverse}, right, vn, vi + ::DynamicTransformationContext{isinverse}, right::Distribution, vn, vi ) where {isinverse} # vi[vn, right] always provides the value in unlinked space. x = vi[vn, right] @@ -31,7 +31,7 @@ function tilde_assume!!( return x, vi end -function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi) +function tilde_observe!!(::DynamicTransformationContext, right::Distribution, left, vn, vi) return tilde_observe!!(DefaultContext(), right, left, vn, vi) end From 6c776e976ba7c6b44448e4de9e1c7063a7004c21 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 18 Sep 2025 12:27:32 +0100 Subject: [PATCH 4/8] Tidy up tilde-pipeline methods and docstrings --- src/context_implementations.jl | 112 +++++++++++++++++++++++++-------- src/contexts/init.jl | 8 ++- 2 files changed, 92 insertions(+), 28 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 92200582c..a8f2d57e6 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -1,15 +1,34 @@ -# assume -function tilde_assume!!(context::AbstractContext, right::Distribution, vn, vi) +""" + DynamicPPL.tilde_assume!!( + context::AbstractContext, + right::Distribution, + vn::VarName, + vi::AbstractVarInfo + ) + +Handle assumed variables, i.e. anything which is not observed (see +[`tilde_observe!!`](@ref)). Accumulate the associated log probability, and return the +sampled value and updated `vi`. + +`vn` is the VarName on the left-hand side of the tilde statement. 
+""" +function tilde_assume!!( + context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) return tilde_assume!!(childcontext(context), right, vn, vi) end -function tilde_assume!!(::DefaultContext, right::Distribution, vn, vi) +function tilde_assume!!( + ::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) y = getindex_internal(vi, vn) f = from_maybe_linked_internal_transform(vi, vn, right) x, inv_logjac = with_logabsdet_jacobian(f, y) vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) return x, vi end -function tilde_assume!!(context::PrefixContext, right::Distribution, vn, vi) +function tilde_assume!!( + context::PrefixContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) # Note that we can't use something like this here: # new_vn = prefix(context, vn) # return tilde_assume!!(childcontext(context), right, new_vn, vi) @@ -22,24 +41,62 @@ function tilde_assume!!(context::PrefixContext, right::Distribution, vn, vi) new_vn, new_context = prefix_and_strip_contexts(context, vn) return tilde_assume!!(new_context, right, new_vn, vi) end - """ - tilde_assume!!(context, right, vn, vi) + DynamicPPL.tilde_assume!!( + context::AbstractContext, + right::DynamicPPL.Submodel, + vn::VarName, + vi::AbstractVarInfo + ) -Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value and updated `vi`. +Evaluate the submodel with the given context. """ -function tilde_assume!!(context, right::DynamicPPL.Submodel, vn, vi) +function tilde_assume!!( + context::AbstractContext, right::DynamicPPL.Submodel, vn::VarName, vi::AbstractVarInfo +) return _evaluate!!(right, vi, context, vn) end -# observe -function tilde_observe!!(context::AbstractContext, right, left, vn, vi) +""" + tilde_observe!!( + context::AbstractContext, + right::Distribution, + left, + vn::Union{VarName, Nothing}, + vi::AbstractVarInfo + ) + +This function handles observed variables, which may be: + +- literals on the left-hand side, e.g., `3.0 ~ Normal()` +- a model input, e.g. `x ~ Normal()` in a model `@model f(x) ... end` +- a conditioned or fixed variable, e.g. `x ~ Normal()` in a model `model | (; x = 3.0)`. + +The relevant log-probability associated with the observation is computed and accumulated in +the VarInfo object `vi` (except for fixed variables, which do not contribute to the +log-probability). + +`left` is the actual value that the left-hand side evaluates to. `vn` is the VarName on the +left-hand side, or `nothing` if the left-hand side is a literal value. + +Observations of submodels are not yet supported in DynamicPPL. +""" +function tilde_observe!!( + context::AbstractContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) return tilde_observe!!(childcontext(context), right, left, vn, vi) end - -# `PrefixContext` -function tilde_observe!!(context::PrefixContext, right, left, vn, vi) +function tilde_observe!!( + context::PrefixContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal # value. For the need for prefix_and_strip_contexts rather than just prefix, see the # comment in `tilde_assume!!`. 
@@ -50,21 +107,22 @@ function tilde_observe!!(context::PrefixContext, right, left, vn, vi) end return tilde_observe!!(new_context, right, left, new_vn, vi) end - -""" - tilde_observe!!(context, right, left, vn, vi) - -Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the observed value and updated `vi`. - -Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name -and indices; if needed, these can be accessed through this function, though. -""" -function tilde_observe!!(::DefaultContext, right::Distribution, left, vn, vi) +function tilde_observe!!( + ::DefaultContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) vi = accumulate_observe!!(vi, right, left, vn) return left, vi end - -function tilde_observe!!(::DefaultContext, ::DynamicPPL.Submodel, left, vn, vi) +function tilde_observe!!( + ::AbstractContext, + ::DynamicPPL.Submodel, + left, + vn::Union{VarName,Nothing}, + ::AbstractVarInfo, +) throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) end diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 7eea73f66..4baca1b57 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -191,6 +191,12 @@ function tilde_assume!!( return x, vi end -function tilde_observe!!(::InitContext, right, left, vn, vi) +function tilde_observe!!( + ::InitContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) return tilde_observe!!(DefaultContext(), right, left, vn, vi) end From 992569f27bc4078ec405c7f2d8987a71e9d3ca4a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 18 Sep 2025 12:35:31 +0100 Subject: [PATCH 5/8] Fix tests --- test/lkj.jl | 5 ++--- test/sampler.jl | 54 ------------------------------------------------- 2 files changed, 2 insertions(+), 57 deletions(-) diff --git a/test/lkj.jl b/test/lkj.jl index 03e744b84..5c5603aba 100644 --- a/test/lkj.jl +++ b/test/lkj.jl @@ -21,7 +21,7 @@ _lkj_atol = 0.05 @testset "Sample from x ~ LKJ(2, 1)" begin model = lkj_prior_demo() - for init_strategy in [PriorInit(), UniformInit()] + for init_strategy in [InitFromPrior(), InitFromUniform()] samples = [ last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples ] @@ -32,8 +32,7 @@ end @testset "Sample from x ~ LKJCholesky(2, 1, $(uplo))" for uplo in ['U', 'L'] model = lkj_chol_prior_demo(uplo) - # `SampleFromPrior` will sample in unconstrained space. - for init_strategy in [PriorInit(), UniformInit()] + for init_strategy in [InitFromPrior(), InitFromUniform()] samples = [ last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples ] diff --git a/test/sampler.jl b/test/sampler.jl index c812de938..5380ad17e 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -113,60 +113,6 @@ end end - @testset "SampleFromPrior and SampleUniform" begin - @model function gdemo(x, y) - s ~ InverseGamma(2, 3) - m ~ Normal(2.0, sqrt(s)) - x ~ Normal(m, sqrt(s)) - return y ~ Normal(m, sqrt(s)) - end - - model = gdemo(1.0, 2.0) - N = 1_000 - - chains = sample(model, SampleFromPrior(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # Expected value of ``X`` where ``X ~ N(2, ...)`` is 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15 - - # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. 
- @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 - - chains = sample(model, SampleFromUniform(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # `m` is Gaussian, i.e. no transformation is used, so it - # will be drawn from U[-2, 2] and its mean should be 0. - @test mean(vi[@varname(m)] for vi in chains) ≈ 0.0 atol = 0.1 - - # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. - @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 - end - - @testset "init" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - N = 1000 - chain_init = sample(model, SampleFromUniform(), N; progress=false) - - for vn in keys(first(chain_init)) - if AbstractPPL.subsumes(@varname(s), vn) - # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. - dist = InverseGamma(2, 3) - b = DynamicPPL.link_transform(dist) - @test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11 - elseif AbstractPPL.subsumes(@varname(m), vn) - # `m ~ Normal(0, sqrt(s))` and its constrained value is the same. - @test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11 - else - error("Unknown variable name: $vn") - end - end - end - end - @testset "Initial parameters" begin # dummy algorithm that just returns initial value and does not perform any sampling abstract type OnlyInitAlg end From 6974cc1b06618d62936bef3e4df30245e46c21ad Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 18 Sep 2025 14:19:40 +0100 Subject: [PATCH 6/8] fix ambiguity --- src/transforming.jl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/transforming.jl b/src/transforming.jl index 5465b2ff2..589dca031 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -13,7 +13,10 @@ struct DynamicTransformationContext{isinverse} <: AbstractContext end NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume!!( - ::DynamicTransformationContext{isinverse}, right::Distribution, vn, vi + ::DynamicTransformationContext{isinverse}, + right::Distribution, + vn::VarName, + vi::AbstractVarInfo, ) where {isinverse} # vi[vn, right] always provides the value in unlinked space. x = vi[vn, right] @@ -31,7 +34,13 @@ function tilde_assume!!( return x, vi end -function tilde_observe!!(::DynamicTransformationContext, right::Distribution, left, vn, vi) +function tilde_observe!!( + ::DynamicTransformationContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) return tilde_observe!!(DefaultContext(), right, left, vn, vi) end From 77a87101a67f64ba8e5f41dc4df3445c89c8c60a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Sep 2025 14:55:26 +0100 Subject: [PATCH 7/8] Add changelog --- HISTORY.md | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index ddbe67842..a71bb6bd1 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,8 +2,66 @@ ## 0.38.0 -The `varname_leaves` and `varname_and_value_leaves` functions have been moved to AbstractPPL.jl. -Their behaviour is otherwise identical. +**Breaking changes** + +### Introduction of `InitContext` + +DynamicPPL 0.38 introduces a new evaluation context, `InitContext`. +It is used to generate fresh values for random variables in a model. + +Evaluation contexts are stored inside a `DynamicPPL.Model` object, and control what happens with tilde-statements when a model is run. 
+The two major leaf (basic) contexts are `DefaultContext` and, now, `InitContext`.
+`DefaultContext` is the default context, and it simply uses the values that are already stored in the `VarInfo` object passed to the model evaluation function.
+On the other hand, `InitContext` ignores values in the VarInfo object and inserts new values obtained from a specified source.
+(It follows also that the VarInfo being used may be empty, which means that `InitContext` is now also the way to obtain a fresh VarInfo for a model.)
+
+DynamicPPL 0.38 provides three flavours of _initialisation strategies_, which are specified as the second argument to `InitContext`:
+
+  - `InitContext(rng, InitFromPrior())`: New values are sampled from the prior distribution (on the right-hand side of the tilde).
+  - `InitContext(rng, InitFromUniform(a, b))`: New values are sampled uniformly from the interval `[a, b]`, and then invlinked to the support of the distribution on the right-hand side of the tilde.
+  - `InitContext(rng, InitFromParams(p, fallback))`: New values are obtained by indexing into the `p` object, which can be a `NamedTuple` or `Dict{<:VarName}`. If a variable is not found in `p`, then the `fallback` strategy is used, which is simply another of these strategies. In particular, `InitFromParams` enables the case where different variables are to be initialised from different sources.
+
+(It is possible to define your own initialisation strategy; users who wish to do so are referred to the DynamicPPL API documentation and source code.)
+
+**The main impact on the upcoming Turing.jl release** is that, instead of providing initial values for sampling, the user will be expected to provide an initialisation strategy.
+This is a more flexible approach, and not only solves a number of pre-existing issues with initialisation of Turing models, but also improves the clarity of user code.
+In particular:
+
+  - When providing a set of fixed parameters (i.e. `InitFromParams(p)`), `p` must now either be a NamedTuple or a Dict. Previously, Vectors were allowed, which is error-prone because the ordering of variables in a VarInfo is not obvious.
+  - The parameters in `p` must now always be provided in unlinked space (i.e., in the space of the distribution on the right-hand side of the tilde). Previously, whether a parameter was expected to be in linked or unlinked space depended on whether the VarInfo was linked or not, which was confusing.
+
+### Removal of `SamplingContext`
+
+For developers working on DynamicPPL, `InitContext` now completely replaces what used to be `SamplingContext`, `SampleFromPrior`, and `SampleFromUniform`.
+Evaluating a model with `SamplingContext(SampleFromPrior())` (e.g. with `DynamicPPL.evaluate_and_sample!!(model, VarInfo(), SampleFromPrior())`) has a direct one-to-one replacement in `DynamicPPL.init!!(model, VarInfo(), InitFromPrior())`.
+Please see the docstring of `init!!` for more details.
+Likewise `SampleFromUniform()` can be replaced with `InitFromUniform()`.
+`InitFromParams()` provides new functionality which previously used to be implemented in the roundabout way of manipulating the VarInfo (e.g. using `unflatten`, or even more hackily by directly modifying values in the VarInfo), and then evaluating using `DefaultContext`.
+
+The main change that this is likely to create is for those who are implementing samplers or inference algorithms.
+The exact way in which this happens will be detailed in the Turing.jl changelog when a new release is made.
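+
+For ordinary DynamicPPL usage, the replacement described above looks roughly as follows (an illustrative sketch only; the exact signatures are documented in the docstrings of `init!!` and the initialisation strategies):
+
+```julia
+using DynamicPPL, Distributions
+
+@model function demo()
+    x ~ Normal()
+    y ~ Normal(x)
+end
+model = demo()
+
+# Fresh values from the prior (previously `SamplingContext(SampleFromPrior())`):
+_, vi = DynamicPPL.init!!(model, VarInfo(), InitFromPrior())
+
+# Fix `x` and fall back to the prior for `y`
+# (previously done by manually manipulating the VarInfo):
+_, vi = DynamicPPL.init!!(model, VarInfo(), InitFromParams((; x=1.0), InitFromPrior()))
+```
+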
+Broadly speaking, though, `SamplingContext(MySampler())` will be removed so if your sampler needs custom behaviour with the tilde-pipeline you will likely have to define your own context. + +### Simplification of the tilde-pipeline + +There are now only two functions in the tilde-pipeline that need to be overloaded to change the behaviour of tilde-statements, namely, `tilde_assume!!` and `tilde_observe!!`. +Other functions such as `tilde_assume` and `assume` (and their `observe` counterparts) have been removed. + +Note that this was effectively already the case in DynamicPPL 0.37 (where they were just wrappers around each other). +The separation of these functions was primarily implemented to avoid performing extra work where unneeded (e.g. to not calculate the log-likelihood when `PriorContext` was being used). This functionality has since been replaced with accumulators (see the 0.37 changelog for more details). + +**Other changes** + +### Reimplementation of functions using `InitContext` + +A number of functions have been reimplemented and unified with the help of `InitContext`. +In particular, this release brings substantial performance improvements for `returned` and `predict`. +Their APIs are the same. + +### Upstreaming of VarName functionality + +The implementation of the `varname_leaves` and `varname_and_value_leaves` functions have been moved to AbstractPPL.jl. +Their behaviour is otherwise identical, and they are still accessible from the DynamicPPL module (though still not exported). ## 0.37.3 From f4e5f4b369cdfad04d77f303e92373dd0479b341 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Sep 2025 16:20:21 +0100 Subject: [PATCH 8/8] Update HISTORY.md Co-authored-by: Markus Hauru --- HISTORY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index a71bb6bd1..d67afcbfe 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -36,7 +36,7 @@ For developers working on DynamicPPL, `InitContext` now completely replaces what Evaluating a model with `SamplingContext(SampleFromPrior())` (e.g. with `DynamicPPL.evaluate_and_sample!!(model, VarInfo(), SampleFromPrior())` has a direct one-to-one replacement in `DynamicPPL.init!!(model, VarInfo(), InitFromPrior())`. Please see the docstring of `init!!` for more details. Likewise `SampleFromUniform()` can be replaced with `InitFromUniform()`. -`InitFromParams()` provides new functionality which previously used to be implemented in the roundabout way of manipulating the VarInfo (e.g. using `unflatten`, or even more hackily by directly modifying values in the VarInfo), and then evaluating using `DefaultContext`. +`InitFromParams()` provides new functionality which was previously implemented in the roundabout way of manipulating the VarInfo (e.g. using `unflatten`, or even more hackily by directly modifying values in the VarInfo), and then evaluating using `DefaultContext`. The main change that this is likely to create is for those who are implementing samplers or inference algorithms. The exact way in which this happens will be detailed in the Turing.jl changelog when a new release is made.