Description
Consider the following invocations:
```julia
chain1 = sample(model, spl, 100)                    # one chain, serial interface
chain2 = sample(model, spl, MCMCThreads(), 100, 1)  # one chain, parallel interface
sample(model, spl, 100; initial_state=DynamicPPL.loadstate(chain1))                        # resume serially
sample(model, spl, MCMCThreads(), 100, 1; initial_state=DynamicPPL.loadstate(chain2))   # resume in parallel
```
AbstractMCMC requires different types for `initial_state`. In the third line it has to be just a state object, whereas the fourth line requires it to be a vector with the state as its only element. This isn't an issue with AbstractMCMC at all; that's totally fine.
The issue lies with the `DynamicPPL.loadstate` function. The problem is that `chain1` and `chain2` will, in general, have the same contents*, so calling `loadstate()` on the outputs of the two calls above will yield the same thing. This means that we have to choose what `loadstate` returns when called on a single chain: either it returns a vector, or it returns just the state. Whichever we pick, one of the two resuming calls to `sample` will break.
I think `loadstate` should take extra arguments specifying whether its result is being passed to the parallel `sample`, and if so, the number of chains. That would also make it possible to give better error messages when incompatible chains are used for save/resume.
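A minimal, purely hypothetical sketch of what such an interface could look like is below. `MockChain` and `loadstate_sketch` (with its `parallel` and `nchains` keywords) are invented for illustration and are not part of DynamicPPL or MCMCChains. The sketch also assumes the chain always stores its sampler state as a vector with one entry per chain, as argued in the footnote.

```julia
# HYPOTHETICAL sketch: none of these names exist in DynamicPPL or MCMCChains.

# Stand-in for a chain object. Assumption: the sampler state is always stored
# as a vector with one entry per chain, however the chain was sampled.
struct MockChain
    states::Vector{Any}
end

"""
    loadstate_sketch(chain; parallel=false, nchains=1)

Return a value suitable for `initial_state`:
- `parallel=false`: the single state object (for serial `sample`).
- `parallel=true`: a vector of `nchains` states (for parallel `sample`).
"""
function loadstate_sketch(chain::MockChain; parallel::Bool=false, nchains::Int=1)
    n = length(chain.states)
    if !parallel
        # Serial sampling can only resume from exactly one stored state.
        n == 1 || throw(ArgumentError(
            "cannot resume serial sampling from a chain with $n states; use parallel=true"))
        return only(chain.states)
    end
    # Parallel sampling needs exactly one stored state per requested chain.
    n == nchains || throw(ArgumentError(
        "chain has $n states but $nchains chains were requested"))
    return copy(chain.states)
end

# Usage: the same single-chain object serves both resuming calls.
c = MockChain(Any[:state_a])
loadstate_sketch(c)                            # returns :state_a
loadstate_sketch(c; parallel=true, nchains=1)  # returns Any[:state_a]
```

Because the caller states how the result will be used, the function never has to guess between a bare state and a singleton vector, and a mismatch between stored and requested chains is reported before sampling starts.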
This is a really, really small thing, so hardly high priority, but it is a rough edge on the current interface.
Also, this is not to say that a bigger refactoring of `resume_from` / `initial_state` isn't needed; but this is quite an easy fix, so I don't think it would hold up any other refactoring.
\* For MCMCChains, this was true prior to v7.3.0. Since then, `chain1` and `chain2` will in fact differ: internally, `chain1` stores its state as a single state object, and `chain2` stores it as a singleton vector. But I don't think it should, and in fact I think the design of MCMCChains in this area is wrong. It should not matter which invocation the chain was sampled with; both of those lines sample a single chain, and the returned data structure should be agnostic to how that single chain was sampled. So I think the state should always be stored as a vector, which is what I've done in FlexiChains.