Hi,
There is a problem with predict_latent_z in the cmhe module:
    def predict_latent_z(model, x):
        model, _ = model
        gates = model.model.embedding(x)
        z_gate_probs = torch.exp(gates).sum(axis=2).detach().numpy()
        return z_gate_probs
The function accesses model.model, which causes an error at runtime.
Also, axis=2 is hardcoded, which causes a problem when k > 1.
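
A possible fix might look like the sketch below. It is only an illustration, not the library's code: ToyCMHE is a hypothetical stand-in for the fitted torch module, and I am assuming that embedding returns joint log-probabilities of shape (n, k, g) and that the intended reduction is over the last axis. The sketch unwraps a nested .model attribute only when one exists, and sums over dim=-1 instead of a hardcoded axis.

```python
import torch
import torch.nn as nn


class ToyCMHE(nn.Module):
    """Hypothetical stand-in for the fitted torch module (not the library's class)."""

    def __init__(self, in_dim, k, g):
        super().__init__()
        self.k, self.g = k, g
        self.linear = nn.Linear(in_dim, k * g)

    def embedding(self, x):
        # Assumption: the real method returns joint log-probabilities
        # of shape (n, k, g). Here they are normalized over all k * g cells.
        logits = self.linear(x)
        return torch.log_softmax(logits, dim=1).view(-1, self.k, self.g)


def predict_latent_z_sketch(model, x):
    """Variant of predict_latent_z without model.model or axis=2 hardcoded."""
    torch_model, _ = model
    # Unwrap only if a nested `.model` attribute is actually present.
    torch_model = getattr(torch_model, "model", torch_model)
    gates = torch_model.embedding(x)
    # Reduce over the last axis, whatever its index is.
    return torch.exp(gates).sum(dim=-1).detach().numpy()
```

With this version, passing either (module, other) or (wrapper_with_model_attribute, other) as the first argument takes the same code path.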