Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.8.1"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
Expand All @@ -15,6 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
AbstractMCMC = "1"
Bijectors = "0.5.2, 0.6, 0.7, 0.8"
Distributions = "0.22, 0.23"
ExprTools = "0.1.1"
MacroTools = "0.5.1"
ZygoteRules = "0.2"
julia = "1"
Expand Down
6 changes: 1 addition & 5 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Bijectors
using MacroTools

import AbstractMCMC
import ExprTools
import ZygoteRules

import Random
Expand Down Expand Up @@ -51,25 +52,20 @@ export AbstractVarInfo,
inspace,
subsumes,
# Compiler
ModelGen,
@model,
@varname,
# Utilities
vectorize,
reconstruct,
reconstruct!,
Sample,
Chain,
init,
vectorize,
set_resume!,
# Model
ModelGen,
Model,
getmissings,
getargnames,
getdefaults,
getgenerator,
# Samplers
Sampler,
SampleFromPrior,
Expand Down
111 changes: 71 additions & 40 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,27 @@ end
Builds the `model_info` dictionary from the model's expression.
"""
function build_model_info(input_expr)
# Extract model name (:name), arguments (:args), (:kwargs) and definition (:body)
modeldef = MacroTools.splitdef(input_expr)
# Function body of the model is empty
# Break up the model definition and extract its name, arguments, and function body
modeldef = ExprTools.splitdef(input_expr)

# Print a warning if function body of the model is empty
warn_empty(modeldef[:body])
# Construct model_info dictionary

## Construct model_info dictionary

# Shortcut if the model does not have any arguments
if !haskey(modeldef, :args)
modelinfo = Dict(
:name => modeldef[:name],
:main_body => modeldef[:body],
:arg_syms => [],
:args_nt => NamedTuple(),
:defaults_nt => NamedTuple(),
:args => [],
:modeldef => modeldef,
)
return modelinfo
end

# Extracting the argument symbols from the model definition
arg_syms = map(modeldef[:args]) do arg
Expand Down Expand Up @@ -158,7 +174,7 @@ function build_model_info(input_expr)
:args_nt => args_nt,
:defaults_nt => defaults_nt,
:args => args,
:whereparams => modeldef[:whereparams]
:modeldef => modeldef,
)

return model_info
Expand Down Expand Up @@ -318,48 +334,63 @@ hasmissing(T::Type) = false
Builds the output expression.
"""
function build_output(model_info)
# Arguments with default values
## Build the anonymous evaluator from the user-provided model definition

# Remove the name and use `function (....)` syntax
modeldef = model_info[:modeldef]
delete!(modeldef, :name)
modeldef[:head] = :function

# Define the input arguments (positional + keyword arguments), without default values
origargs = map(vcat(get(modeldef, :args, Any[]), get(modeldef, :kwargs, Any[]))) do arg
Meta.isexpr(arg, :kw) && length(arg.args) >= 1 ? arg.args[1] : arg
end

# Add our own arguments
newargs = Any[:(_rng::$(Random.AbstractRNG)),
:(_model::$(DynamicPPL.Model)),
:(_varinfo::$(DynamicPPL.AbstractVarInfo)),
:(_sampler::$(DynamicPPL.AbstractSampler)),
:(_context::$(DynamicPPL.AbstractContext))]
combinedargs = vcat(newargs, origargs)

# Delete keyword arguments and update positional arguments
delete!(modeldef, :kwargs)
modeldef[:args] = combinedargs

# Replace function body
modeldef[:body] = model_info[:main_body]

## Extract other relevant information

# All arguments with default values (if existent)
args = model_info[:args]
# Argument symbols without default values
arg_syms = model_info[:arg_syms]
# Arguments namedtuple
# Named tuple of all arguments
args_nt = model_info[:args_nt]
# Default values of the arguments
# Arguments namedtuple

# Named tuple of the default values of the arguments
defaults_nt = model_info[:defaults_nt]
# Where parameters
whereparams = model_info[:whereparams]
# Model generator name
model_gen = model_info[:name]
# Main body of the model
main_body = model_info[:main_body]

unwrap_data_expr = Expr(:block)
for var in arg_syms
push!(unwrap_data_expr.args,
:($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var)))
end

@gensym(evaluator, generator)
generator_kw_form = isempty(args) ? () : (:($generator(;$(args...)) = $generator($(arg_syms...))),)
model_gen_constructor = :($(DynamicPPL.ModelGen){$(Tuple(arg_syms))}($generator, $defaults_nt))
# Model name
model = model_info[:name]

return quote
function $evaluator(
_rng::$(Random.AbstractRNG),
_model::$(DynamicPPL.Model),
_varinfo::$(DynamicPPL.AbstractVarInfo),
_sampler::$(DynamicPPL.AbstractSampler),
_context::$(DynamicPPL.AbstractContext),
)
$unwrap_data_expr
$main_body
end
# Define model definition with only keyword arguments
if isempty(args)
model_kwform = ()
else
# All arguments without default values (i.e., only symbols)
arg_syms = model_info[:arg_syms]

$generator($(args...)) = $(DynamicPPL.Model)($evaluator, $args_nt, $model_gen_constructor)
$(generator_kw_form...)
model_kwform = (:($model(; $(args...)) = $model($(arg_syms...))),)
end

$(Base).@__doc__ $model_gen = $model_gen_constructor
@gensym(evaluator)
return quote
$(Base).@__doc__ function $model($(args...))
$evaluator = $(ExprTools.combinedef(modeldef))
return $(DynamicPPL.Model)($evaluator, $args_nt, $defaults_nt)
end
$(model_kwform...)
end
end

Expand Down
Loading