From d0478f956c34a36e0e2cbb9b45692c89de165524 Mon Sep 17 00:00:00 2001 From: DLPerf <1870988096@qq.com> Date: Mon, 6 Sep 2021 09:37:16 +0800 Subject: [PATCH] Improve performance --- rex_gym/agents/ppo/memory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rex_gym/agents/ppo/memory.py b/rex_gym/agents/ppo/memory.py index b38b718..a7d507c 100644 --- a/rex_gym/agents/ppo/memory.py +++ b/rex_gym/agents/ppo/memory.py @@ -71,10 +71,10 @@ def append(self, transitions, rows=None): self._max_length, message='max length exceeded') append_ops = [] + timestep = tf.gather(self._length, rows) + indices = tf.stack([rows, timestep], 1) with tf.control_dependencies([assert_max_length]): for buffer_, elements in zip(self._buffers, transitions): - timestep = tf.gather(self._length, rows) - indices = tf.stack([rows, timestep], 1) append_ops.append(tf.compat.v1.scatter_nd_update(buffer_, indices, elements)) with tf.control_dependencies(append_ops): episode_mask = tf.reduce_sum(tf.one_hot(rows, self._capacity, dtype=tf.int32), 0)