|  | 1 | +--- | 
|  | 2 | +title: Predictive Distributions | 
|  | 3 | +engine: julia | 
|  | 4 | +--- | 
|  | 5 | + | 
|  | 6 | +```{julia} | 
|  | 7 | +#| echo: false | 
|  | 8 | +#| output: false | 
|  | 9 | +using Pkg; | 
|  | 10 | +Pkg.instantiate(); | 
|  | 11 | +``` | 
|  | 12 | + | 
|  | 13 | +Standard MCMC sampling methods return values of the parameters of the model. | 
|  | 14 | +However, it is often also useful to generate new data points using the model, given a distribution of the parameters. | 
|  | 15 | +Turing.jl allows you to do this using the `predict` function, along with conditioning syntax. | 
|  | 16 | + | 
|  | 17 | +Consider the following simple model, where we observe some normally-distributed data `X` and want to learn about its mean `m`. | 
|  | 18 | + | 
|  | 19 | +```{julia} | 
|  | 20 | +using Turing | 
|  | 21 | +@model function f(N) | 
|  | 22 | +    m ~ Normal() | 
|  | 23 | +    X ~ filldist(Normal(m), N) | 
|  | 24 | +end | 
|  | 25 | +``` | 
|  | 26 | + | 
|  | 27 | +Notice first how we have not specified `X` as an argument to the model. | 
|  | 28 | +This allows us to use Turing's conditioning syntax to specify whether we want to provide observed data or not. | 
|  | 29 | + | 
|  | 30 | +::: {.callout-note} | 
|  | 31 | +If you want to specify `X` as an argument to the model, then to mark it as being unobserved, you have to instantiate the model again with `X = missing` or `X = fill(missing, N)`. | 
|  | 32 | +Whether you use `missing` or `fill(missing, N)` depends on whether `X` is treated as a single distribution (e.g. with `filldist` or `product_distribution`), or as multiple independent distributions (e.g. with `.~` or a for loop over `eachindex(X)`). | 
|  | 33 | +This is rather finicky, so we recommend using the current approach: conditioning and deconditioning `X` as a whole should work regardless of how `X` is defined in the model. | 
|  | 34 | +::: | 
|  | 35 | + | 
|  | 36 | +```{julia} | 
|  | 37 | +# Generate some synthetic data | 
|  | 38 | +N = 5 | 
|  | 39 | +true_m = 3.0 | 
|  | 40 | +X = rand(Normal(true_m), N) | 
|  | 41 | + | 
|  | 42 | +# Instantiate the model with observed data | 
|  | 43 | +model = f(N) | (; X = X) | 
|  | 44 | + | 
|  | 45 | +# Sample from the posterior | 
|  | 46 | +chain = sample(model, NUTS(), 1_000; progress=false) | 
|  | 47 | +mean(chain[:m]) | 
|  | 48 | +``` | 
|  | 49 | + | 
|  | 50 | +## Posterior predictive distribution | 
|  | 51 | + | 
|  | 52 | +`chain[:m]` now contains samples from the posterior distribution of `m`. | 
|  | 53 | +If we use these samples of the parameters to generate new data points, we obtain the *posterior predictive distribution*. | 
|  | 54 | +Statistically, this is defined as | 
|  | 55 | + | 
|  | 56 | +$$ | 
|  | 57 | +p(\tilde{x} | \mathbf{X}) = \int p(\tilde{x} | \theta) p(\theta | \mathbf{X}) d\theta, | 
|  | 58 | +$$ | 
|  | 59 | + | 
|  | 60 | +where $\tilde{x}$ is the new data which you wish to draw, $\theta$ are the model parameters, and $\mathbf{X}$ is the observed data. | 
|  | 61 | +$p(\tilde{x} | \theta)$ is the distribution of the new data given the parameters, which is specified in the Turing.jl model (the `X ~ ...` line); and $p(\theta | \mathbf{X})$ is the posterior distribution, as given by the Markov chain. | 
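
In practice this integral is rarely evaluated exactly. As a sketch of the usual approach (the number of draws $S$ and the superscript notation $\theta^{(s)}$ are introduced here for illustration and do not appear elsewhere in this page), it can be approximated by Monte Carlo using the posterior draws stored in the chain:

$$
p(\tilde{x} | \mathbf{X}) \approx \frac{1}{S} \sum_{s=1}^{S} p(\tilde{x} | \theta^{(s)}), \qquad \theta^{(s)} \sim p(\theta | \mathbf{X}).
$$

Equivalently, drawing one $\tilde{x}^{(s)} \sim p(\tilde{x} | \theta^{(s)})$ for each posterior draw $\theta^{(s)}$ yields samples from the posterior predictive distribution.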
|  | 62 | + | 
|  | 63 | +To obtain samples of $\tilde{x}$, we need to first remove the observed data from the model (or 'decondition' it). | 
|  | 64 | +This means that when the model is evaluated, it will sample a new value for `X`. | 
|  | 65 | + | 
|  | 66 | +```{julia} | 
|  | 67 | +predictive_model = decondition(model) | 
|  | 68 | +``` | 
|  | 69 | + | 
|  | 70 | +::: {.callout-tip} | 
|  | 71 | +## Selective deconditioning | 
|  | 72 | + | 
|  | 73 | +If you only want to decondition a single variable `X`, you can use `decondition(model, @varname(X))`. | 
|  | 74 | +::: | 
|  | 75 | + | 
|  | 76 | +To demonstrate how this deconditioned model can generate new data, we can fix the value of `m` to be its mean and evaluate the model: | 
|  | 77 | + | 
|  | 78 | +```{julia} | 
|  | 79 | +predictive_model_with_mean_m = predictive_model | (; m = mean(chain[:m])) | 
|  | 80 | +rand(predictive_model_with_mean_m) | 
|  | 81 | +``` | 
|  | 82 | + | 
|  | 83 | +This has given us a single sample of `X` given the mean value of `m`. | 
|  | 84 | +Of course, to take our Bayesian uncertainty into account, we want to use the full posterior distribution of `m`, not just its mean. | 
|  | 85 | +To do so, we use `predict`, which _effectively_ does the same as above but for every sample in the chain: | 
|  | 86 | + | 
|  | 87 | +```{julia} | 
|  | 88 | +predictive_samples = predict(predictive_model, chain) | 
|  | 89 | +``` | 
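
To make the idea behind this step concrete, here is a minimal plain-Julia sketch of the same ancestral-sampling scheme for this particular model. It is not how `predict` is implemented: `manual_predict` is a hypothetical helper, and `m_draws` is a made-up stand-in for the posterior draws (in the tutorial these would come from something like `vec(chain[:m])`).

```julia
using Random

# Hypothetical helper (not part of Turing): for each posterior draw of `m`,
# draw one new data set X of length N from Normal(m, 1).
function manual_predict(rng, m_draws, N)
    return [m .+ randn(rng, N) for m in m_draws]
end

rng = Xoshiro(1)

# Placeholder stand-in for posterior draws of `m` (illustrative values only).
m_draws = 2.5 .+ 0.4 .* randn(rng, 1_000)

# One predictive data set per posterior draw.
X_new = manual_predict(rng, m_draws, 5)
```

Each element of `X_new` is one simulated data set, so the spread across elements reflects both the observation noise and the uncertainty in `m`.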
|  | 90 | + | 
|  | 91 | +::: {.callout-tip} | 
|  | 92 | +## Reproducibility | 
|  | 93 | + | 
|  | 94 | +`predict`, like many other Julia functions, takes an optional `rng` as its first argument. | 
|  | 95 | +This controls the generation of new `X` samples, and makes your results reproducible. | 
|  | 96 | +::: | 
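
As a rough illustration of the tip above, the following sketch passes two identically seeded RNGs to `predict` and compares the results. The data values and seeds are arbitrary placeholders, and the model is repeated only so that the snippet is self-contained.

```julia
using Turing, Random

@model function f(N)
    m ~ Normal()
    X ~ filldist(Normal(m), N)
end

# Placeholder data, for illustration only.
model = f(5) | (; X = [2.5, 3.1, 2.9, 3.4, 3.0])
chain = sample(Xoshiro(1), model, NUTS(), 200; progress=false)
predictive_model = decondition(model)

# Two calls with identically seeded RNGs.
p1 = predict(Xoshiro(468), predictive_model, chain)
p2 = predict(Xoshiro(468), predictive_model, chain)

# Compare the predicted values as plain arrays.
same_predictions = Array(p1) == Array(p2)
```

If the RNG fully determines the predictions, as the tip suggests, `same_predictions` should be `true`.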
|  | 97 | + | 
|  | 98 | +::: {.callout-note} | 
|  | 99 | +`predict` returns a Chains object itself, which will only contain the newly predicted variables. | 
|  | 100 | +If you want to also retain the original parameters, you can use `predict(rng, predictive_model, chain; include_all=true)`. | 
|  | 101 | +Note that the `include_all` keyword argument does not work unless you also pass an RNG as the first argument; you can use `Random.default_rng()` if you aren't fussed. | 
|  | 102 | +(This will be fixed in the next release of Turing.) | 
|  | 103 | +::: | 
|  | 104 | + | 
|  | 105 | +We can visualise the predictive distribution by combining all the samples and making a density plot: | 
|  | 106 | + | 
|  | 107 | +```{julia} | 
|  | 108 | +using StatsPlots: density, density!, vline! | 
|  | 109 | + | 
|  | 110 | +predicted_X = vcat([predictive_samples[Symbol("X[$i]")] for i in 1:N]...) | 
|  | 111 | +density(predicted_X, label="Posterior predictive") | 
|  | 112 | +``` | 
|  | 113 | + | 
|  | 114 | +## Prior predictive distribution | 
|  | 115 | + | 
|  | 116 | +Alternatively, if we use the prior distribution of the parameters, we obtain the *prior predictive distribution*: | 
|  | 117 | + | 
|  | 118 | +$$ | 
|  | 119 | +p(\tilde{x}) = \int p(\tilde{x} | \theta) p(\theta) d\theta. | 
|  | 120 | +$$ | 
|  | 121 | + | 
|  | 122 | +This is simpler, as there is no need to pass a chain in: we can sample from the deconditioned model directly, using Turing's `Prior` sampler. | 
|  | 123 | + | 
|  | 124 | +```{julia} | 
|  | 125 | +prior_predictive_samples = sample(predictive_model, Prior(), 1_000; progress=false) | 
|  | 126 | +``` | 
|  | 127 | + | 
|  | 128 | +We can visualise the prior predictive distribution in the same way as before. | 
|  | 129 | +Let's compare the two predictive distributions: | 
|  | 130 | + | 
|  | 131 | +```{julia} | 
|  | 132 | +prior_predicted_X = vcat([prior_predictive_samples[Symbol("X[$i]")] for i in 1:N]...) | 
|  | 133 | +density(prior_predicted_X, label="Prior predictive") | 
|  | 134 | +density!(predicted_X, label="Posterior predictive") | 
|  | 135 | +vline!([true_m], label="True mean", linestyle=:dash, color=:black) | 
|  | 136 | +``` | 
|  | 137 | + | 
|  | 138 | +We can see here that the prior predictive distribution is: | 
|  | 139 | + | 
|  | 140 | +1. Wider than the posterior predictive distribution; | 
|  | 141 | +2. Centred on the prior mean of `m` (which is 0), rather than the posterior mean (which is close to the true mean of `3`). | 
|  | 142 | + | 
|  | 143 | +Both effects arise because the posterior predictive distribution is informed by the observed data, whereas the prior predictive distribution reflects only the prior. | 
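
For this particular model the comparison can also be checked by hand, because a standard normal prior on `m` combined with a unit-variance normal likelihood is conjugate. The sketch below uses a hypothetical helper, `analytic_predictive`, and placeholder data. It relies on the standard result that the posterior of `m` is normal with mean `sum(X) / (N + 1)` and variance `1 / (N + 1)`, so each new data point has predictive variance equal to the observation variance plus the variance of `m`.

```julia
# Hypothetical helper (not part of Turing): closed-form predictive moments for
#   m ~ Normal(0, 1),  X[i] ~ Normal(m, 1)
function analytic_predictive(X)
    N = length(X)
    post_mean = sum(X) / (N + 1)   # posterior mean of m
    post_var = 1 / (N + 1)         # posterior variance of m
    return (
        prior_mean = 0.0,
        prior_var = 1.0 + 1.0,           # observation variance + prior variance of m
        posterior_mean = post_mean,
        posterior_var = 1.0 + post_var,  # observation variance + posterior variance of m
    )
end

# Placeholder data, for illustration only.
result = analytic_predictive([2.5, 3.1, 2.9, 3.4, 3.0])
```

With `N = 5`, the posterior predictive variance of a single new point is `1 + 1/6`, compared with `2` for the prior predictive. This agrees with the narrower, shifted density in the plot.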