@@ -56,16 +56,28 @@ def test_converge(self):
         self.assertLess(loss, 0.2)
 
     def test_scaling(self):
-        self.space = Box(np.array([-10, -5, 100]), np.array([10, -2, 200]))
-        self.policy = SoftDeterministicPolicy(
+        torch.manual_seed(0)
+        state = State(torch.randn(1, STATE_DIM))
+        policy1 = SoftDeterministicPolicy(
             self.model,
             self.optimizer,
-            self.space
+            Box(np.array([-1., -1., -1.]), np.array([1., 1., 1.]))
         )
+        action1, log_prob1 = policy1(state)
+
+        # reset seed and sample same thing, but with different scaling
+        torch.manual_seed(0)
         state = State(torch.randn(1, STATE_DIM))
-        action, log_prob = self.policy(state)
-        tt.assert_allclose(action, torch.tensor([[-3.09055, -4.752777, 188.98222]]))
-        tt.assert_allclose(log_prob, torch.tensor([-0.397002]), rtol=1e-4)
+        policy2 = SoftDeterministicPolicy(
+            self.model,
+            self.optimizer,
+            Box(np.array([-2., -1., -1.]), np.array([2., 1., 1.]))
+        )
+        action2, log_prob2 = policy2(state)
+
+        # check scaling was correct
+        tt.assert_allclose(action1 * torch.tensor([2, 1, 1]), action2)
+        tt.assert_allclose(log_prob1 - np.log(2), log_prob2)
 
 
 if __name__ == '__main__':
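For context, the new assertions follow from the change-of-variables rule for a rescaled policy distribution: stretching one action dimension of the Box by a factor of 2 scales that component of the sampled action by 2 and lowers its log-probability by log 2. Below is a minimal sketch of that relationship, assuming a tanh-squashed Gaussian that is rescaled linearly to the Box; this is an illustration of the rule the test checks, not the actual SoftDeterministicPolicy code.

import torch

torch.manual_seed(0)
dist = torch.distributions.Normal(torch.zeros(3), torch.ones(3))
raw = dist.sample()
squashed = torch.tanh(raw)  # sample in [-1, 1], as a squashed Gaussian policy would produce

# log-prob of the squashed sample (tanh correction summed over dimensions)
log_prob = (dist.log_prob(raw) - torch.log(1 - squashed ** 2 + 1e-6)).sum()

# affine rescale to a Box twice as wide in the first dimension;
# the Jacobian of that map contributes -log(2) to the log-prob
scale = torch.tensor([2., 1., 1.])
action = squashed * scale
scaled_log_prob = log_prob - torch.log(scale).sum()  # equals log_prob - log(2)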