diff --git a/dopamine/jax/networks.py b/dopamine/jax/networks.py index 56efdb0b..99452a5b 100644 --- a/dopamine/jax/networks.py +++ b/dopamine/jax/networks.py @@ -257,6 +257,6 @@ def apply(self, x, num_actions, num_atoms): x = jax.nn.relu(x) x = nn.Dense(x, features=num_actions * num_atoms, kernel_init=initializer) logits = x.reshape((x.shape[0], num_actions, num_atoms)) - probabilities = nn.softmax(logits) + diracs = logits # nn.softmax(logits) q_values = jnp.mean(logits, axis=2) - return atari_lib.RainbowNetworkType(q_values, logits, probabilities) + return atari_lib.RainbowNetworkType(q_values, logits, diracs)