From e53071795b5204f9fa63067af8fc1af00c3ee4d3 Mon Sep 17 00:00:00 2001 From: ddlau Date: Tue, 24 Nov 2020 18:26:21 +0800 Subject: [PATCH] a bug-fix of QR-DQN network definition. --- dopamine/jax/networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dopamine/jax/networks.py b/dopamine/jax/networks.py index 56efdb0b..99452a5b 100644 --- a/dopamine/jax/networks.py +++ b/dopamine/jax/networks.py @@ -257,6 +257,6 @@ def apply(self, x, num_actions, num_atoms): x = jax.nn.relu(x) x = nn.Dense(x, features=num_actions * num_atoms, kernel_init=initializer) logits = x.reshape((x.shape[0], num_actions, num_atoms)) - probabilities = nn.softmax(logits) + diracs = logits # nn.softmax(logits) q_values = jnp.mean(logits, axis=2) - return atari_lib.RainbowNetworkType(q_values, logits, probabilities) + return atari_lib.RainbowNetworkType(q_values, logits, diracs)