@@ -53,22 +53,20 @@ def __init__(
         self._action = None
         self._frames_seen = 0

-    def act(self, state, reward):
-        self.replay_buffer.store(self._state, self._action, reward, state)
+    def act(self, state):
+        self.replay_buffer.store(self._state, self._action, state)
         self._train()
         self._state = state
         self._action = self._choose_action(state)
         return self._action

-    def eval(self, state, _):
-        return self._best_actions(self.q_dist.eval(state))
+    def eval(self, state):
+        return self._best_actions(self.q_dist.eval(state)).item()

     def _choose_action(self, state):
         if self._should_explore():
-            return torch.randint(
-                self.q_dist.n_actions, (len(state),), device=self.q_dist.device
-            )
-        return self._best_actions(self.q_dist.no_grad(state))
+            return np.random.randint(0, self.q_dist.n_actions)
+        return self._best_actions(self.q_dist.no_grad(state)).item()

     def _should_explore(self):
         return (
@@ -77,8 +75,8 @@ def _should_explore(self):
         )

     def _best_actions(self, probs):
-        q_values = (probs * self.q_dist.atoms).sum(dim=2)
-        return torch.argmax(q_values, dim=1)
+        q_values = (probs * self.q_dist.atoms).sum(dim=-1)
+        return torch.argmax(q_values, dim=-1)

     def _train(self):
         if self._should_train():
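The net effect of this change is that the agent now acts in a single environment, returning a plain Python action index via .item() and sampling exploration actions with np.random.randint, instead of operating on a batch of states, and _best_actions reduces the value distribution over its last axis (dim=-1) rather than hard-coded batch dimensions. Below is a minimal, self-contained sketch of that expected-Q computation and the single-environment epsilon-greedy step; the shapes, atom support, and epsilon value are made up for illustration and are not taken from this repository.

import numpy as np
import torch

# Illustrative shapes: probs is (n_actions, n_atoms) for a single state,
# atoms is the (n_atoms,) support of the value distribution.
n_actions, n_atoms = 4, 51
atoms = torch.linspace(-10.0, 10.0, n_atoms)            # value support z_i
probs = torch.softmax(torch.randn(n_actions, n_atoms), dim=-1)

# Expected Q-value per action: Q(s, a) = sum_i p_i(s, a) * z_i.
# Reducing over dim=-1 works with or without a leading batch dimension,
# which is why the diff replaces dim=2 / dim=1 with dim=-1.
q_values = (probs * atoms).sum(dim=-1)                   # shape: (n_actions,)
best_action = torch.argmax(q_values, dim=-1).item()      # plain Python int

# Epsilon-greedy exploration for a single environment: a scalar random
# action index replaces the batched torch.randint call.
epsilon = 0.1  # assumed value for illustration only
if np.random.rand() < epsilon:
    action = np.random.randint(0, n_actions)
else:
    action = best_action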