loadstate doesn't differentiate between single-chain sampling and multi-chain sampling with n_chains = 1 #1049

@penelopeysm

Consider the following invocations:

chain1 = sample(model, spl, 100)
chain2 = sample(model, spl, MCMCThreads(), 100, 1)

sample(model, spl, 100; initial_state=DynamicPPL.loadstate(chain1))
sample(model, spl, MCMCThreads(), 100, 1; initial_state=DynamicPPL.loadstate(chain2))

AbstractMCMC requires different types for initial_state in these two cases. In the third call (the single-chain resume from chain1) it has to be just a state object, whereas the fourth call (the MCMCThreads resume from chain2) requires a vector with the state as its only element. This isn't an issue with AbstractMCMC at all; that's totally fine.

The issue lies with the DynamicPPL.loadstate function. The problem is that chain1 and chain2 will, in general, have the same contents*, and thus calling loadstate() on the output of the two calls above will yield the same thing. This means that we have to make a choice of what loadstate returns when called on a single chain: either it returns a vector, or it returns just the state. That then means that we will break one of the two calls to sample.

I think loadstate should take more arguments specifying whether its result is being passed to a parallel sample call, and if so, the number of chains. That also makes it possible to give better error messages if incompatible chains are used for save/resume.
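
To make the proposal concrete, here is a rough sketch of what such an interface might look like. It is a toy, not the DynamicPPL API: ToyChain and loadstate_sketch are hypothetical names, and the chain is assumed to always store its saved states as a vector with one entry per chain, regardless of how it was sampled. Passing no chain count means "serial resume", and passing one means "parallel resume with this many chains".

```julia
# Hypothetical stand-in for a chain object; NOT an MCMCChains or FlexiChains type.
# Assumption: saved sampler states are always stored as a vector, one per chain.
struct ToyChain
    states::Vector{Any}
end

# Serial resume: return the bare state object.
# Errors unless the chain holds exactly one saved state.
function loadstate_sketch(chain::ToyChain)
    n = length(chain.states)
    n == 1 || throw(ArgumentError(
        "chain holds $n states; pass the number of chains to resume in parallel"))
    return only(chain.states)
end

# Parallel resume: return a vector with one state per requested chain.
# Errors if the saved chain count does not match the requested one.
function loadstate_sketch(chain::ToyChain, n_chains::Integer)
    n = length(chain.states)
    n == n_chains || throw(ArgumentError(
        "chain holds $n states but $n_chains chains were requested"))
    return collect(chain.states)
end

# Usage: the same stored chain serves both kinds of resume.
chain = ToyChain(Any[(step = 100,)])   # placeholder state, purely illustrative
serial_state = loadstate_sketch(chain)        # bare state, for sample(model, spl, N; ...)
parallel_states = loadstate_sketch(chain, 1)  # one-element vector, for the MCMCThreads() call
```

With this shape, the caller states which form of initial_state they need, so the function no longer has to guess from the chain's contents, and a mismatch in chain count fails early with a specific message.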

This is a really, really small thing, so hardly high priority, but it is a rough edge on the current interface.

Also, this is not to say that bigger refactoring isn't needed for resume_from / initial_state; but this is also quite an easy fix and so I don't think this would hold up any other refactoring.

* For MCMCChains, this was true prior to v7.3.0. Since then, chain1 and chain2 will in fact differ. Internally, chain1 will store its state as a single state object, and chain2 will store it as a singleton vector. But I don't think it should, really, and in fact I think the design of MCMCChains in this area is wrong. It should not matter which invocation the chain was sampled with; the fact is that both of those lines sample a single chain, and the returned data structure should be agnostic towards how that single chain was sampled. So I think it should always be stored as a vector, which is what I've done in FlexiChains.
