
Bug: empty free variables in positive phase results in index error#37

Open
sugolov wants to merge 2 commits into extropic-ai:main from sugolov:anton/empty-pos-fix

Conversation


sugolov commented on Mar 12, 2026

Summary

estimate_kl_grad does not handle the edge case where the positive phase has no free variables, i.e. when the model has no hidden nodes or all variables are clamped during the positive phase. Specifically, consider the contrastive divergence gradient

$$ \underbrace{ \mathbb{E}_{p_\text{data}} [\nabla_\theta \mathcal{E}_\theta (x)] }_\text{positive phase} - \underbrace{ \mathbb{E}_{p_\theta} [\nabla_\theta \mathcal{E}_\theta (x)] }_\text{negative phase} $$

thrml does not handle the case where $x$ is entirely visible and clamped to $x \sim p_\text{data}$. estimate_kl_grad implicitly assumes that init_state_positive always contains free variables, indexing into the list to derive the shape of the sampling keys:

keys_pos = jax.random.split(key_pos, init_state_positive[0].shape[:2])

moms_b_pos, moms_w_pos = jax.vmap(
    lambda k_out, i_out: jax.vmap(
        lambda k, i, c: estimate_moments(
            k, bias_nodes, weight_edges, training_spec.program_positive, training_spec.schedule_positive, i, c
        )
    )(k_out, i_out, data + cond_batched_pos)
)(keys_pos, init_state_positive)

When the list is empty, this throws an IndexError.
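As a toy illustration, separate from the full reproduction below, indexing an empty list of initial states is enough to trigger the error (this is a standalone sketch, not thrml code):

```python
# Toy illustration of the failure mode: when all variables are clamped
# in the positive phase, the list of free-variable initial states is
# empty, and the shape lookup indexes into it anyway.
init_state_positive = []  # no free blocks -> empty list

try:
    key_shape = init_state_positive[0].shape[:2]
except IndexError as err:
    print(f"IndexError: {err}")  # list index out of range
```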

Minimal reproduction

import jax
import jax.numpy as jnp
import networkx as nx
from thrml.pgm import SpinNode
from thrml.models.ising import (
    IsingEBM, IsingTrainingSpec, estimate_kl_grad, hinton_init
)
from thrml.block_management import Block
from thrml.block_sampling import SamplingSchedule

G = nx.grid_2d_graph(4, 4)
nodes = [SpinNode() for _ in range(16)]
G = nx.relabel_nodes(G, {coord: nodes[i] for i, coord in enumerate(G.nodes)})
edges = list(G.edges)

key = jax.random.key(0)
model = IsingEBM(
    nodes, edges,
    biases=jax.random.normal(key, (16,)),
    weights=jax.random.normal(key, (len(edges),)),
    beta=jnp.array(1.0),
)

coloring = nx.coloring.greedy_color(G, strategy="DSATUR")
free_blocks = [
    Block([n for n, c in coloring.items() if c == color])
    for color in range(max(coloring.values()) + 1)
]

training_spec = IsingTrainingSpec(
    model,
    data_blocks=[Block(nodes)],
    conditioning_blocks=[],
    positive_sampling_blocks=[],
    negative_sampling_blocks=free_blocks,
    schedule_positive=SamplingSchedule(n_warmup=0, n_samples=1, steps_per_sample=0),
    schedule_negative=SamplingSchedule(n_warmup=0, n_samples=1, steps_per_sample=1),
)

batch_size = 32
data = jax.random.bernoulli(key, 0.5, (batch_size, 16)).astype(jnp.bool_)
init_free = hinton_init(key, model, free_blocks, (batch_size,))

# Crashes: IndexError: list index out of range
estimate_kl_grad(key, training_spec, nodes, edges, [data], [], [], init_free)

Fix: explicitly handle this case in estimate_kl_grad

if len(init_state_positive) == 0:
    # if there are no initial states in pos sampling
    # data[0]: (batch, n_nodes)
    spins = 2 * data[0].astype(float_type) - 1          # (batch, n_nodes)

    ei_idx = jnp.array([bias_nodes.index(e[0]) for e in weight_edges])  # (n_edges,)
    ej_idx = jnp.array([bias_nodes.index(e[1]) for e in weight_edges])  # (n_edges,)

    moms_b_pos = spins                                   # (batch, n_nodes)
    moms_w_pos = spins[:, ei_idx] * spins[:, ej_idx]     # (batch, n_edges)

    moms_b_pos = moms_b_pos[None]                        # (1, batch, n_nodes)
    moms_w_pos = moms_w_pos[None]                        # (1, batch, n_edges)
else:
    keys_pos = jax.random.split(key_pos, init_state_positive[0].shape[:2])
    moms_b_pos, moms_w_pos = jax.vmap(
        lambda k_out, i_out: jax.vmap(
            lambda k, i, c: estimate_moments(
                k, bias_nodes, weight_edges, training_spec.program_positive, training_spec.schedule_positive, i, c
            )
        )(k_out, i_out, data + cond_batched_pos)
    )(keys_pos, init_state_positive)
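The closed-form moments in the new branch can be sanity-checked on a tiny example. The sketch below mirrors the jnp logic in plain NumPy; the data, edge list, and shapes are made up for illustration:

```python
import numpy as np

# Hypothetical sanity check for the fully visible positive phase:
# with all nodes clamped to data, the positive-phase moments are just
# data statistics (spin per node, pairwise spin product per edge).
data = np.array([[1, 0, 1],
                 [0, 0, 1]], dtype=bool)      # (batch=2, n_nodes=3)
edges = [(0, 1), (1, 2)]                      # edge endpoints as node indices

spins = 2 * data.astype(np.float32) - 1       # map {0, 1} -> {-1, +1}
ei = np.array([e[0] for e in edges])          # (n_edges,)
ej = np.array([e[1] for e in edges])          # (n_edges,)

moms_b = spins                                # (batch, n_nodes)
moms_w = spins[:, ei] * spins[:, ej]          # (batch, n_edges)

print(moms_b.tolist())  # [[1.0, -1.0, 1.0], [-1.0, -1.0, 1.0]]
print(moms_w.tolist())  # [[-1.0, -1.0], [1.0, -1.0]]
```

No sampling keys or schedules are touched in this branch, which matches the intent of the fix: the positive-phase expectation is exact when everything is clamped.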

Comments

Let me know if there's a more logical way of handling the shapes, or a better way to set moms_b_pos and moms_w_pos to match the previous semantics.

lockwo self-requested a review on March 13, 2026, 07:13
Comment thread tests/test_ising.py


class TestEstimateKLGradFullyVisible(unittest.TestCase):
"""Test that estimate_kl_grad works when all nodes are visible (no latent variables)."""
Collaborator:

the parenthetical statement isn't needed

Comment thread tests/test_ising.py
"""Test that estimate_kl_grad works when all nodes are visible (no latent variables)."""

def test_fully_visible_ising(self):
import networkx as nx
Collaborator:

should be top level import?

Comment thread tests/test_ising.py
conditioning_blocks=[],
positive_sampling_blocks=[],
negative_sampling_blocks=free_blocks,
schedule_positive=SamplingSchedule(n_warmup=0, n_samples=1, steps_per_sample=0),
Collaborator:

Seems weird we even need to provide a sampling schedule for positive if we don't sample? Should we change this?

Comment thread tests/test_ising.py
init_free,
)

# Gradients should be finite and have correct shapes
Collaborator:

we can add a check that the gradient is just the energy gradient, or do you think this is good enough?

Comment thread thrml/models/ising.py
# data[0]: (batch, n_nodes)
spins = 2 * data[0].astype(float_type) - 1 # (batch, n_nodes)

ei_idx = jnp.array([bias_nodes.index(e[0]) for e in weight_edges]) # (n_edges,)
Collaborator:

I guess we do this elsewhere? Just want to make sure this unroll isn't painful compilation wise
