@@ -85,15 +85,8 @@ def __init__(self,
         self.softmax = nn.Softmax(dim=0)
         self.log_softmax = nn.LogSoftmax(dim=0)

-    def compute_contrastive_loss(self, rollouts, advantages):
-        """
-        Contrastive Predictive Coding for learning representation and density ratio
-        :param rollouts: Storage's instance
-        :param advantage: tensor of shape: (timestep, n_processes, 1)
-        """
-        # FIXME: only compatible with 1D observation
-        num_steps, n_processes, _ = advantages.shape
-
+    def _encode_input_sequences(self, rollouts):
+        num_steps, n_processes, _ = rollouts.actions.shape
         # INPUT SEQUENCES AND MASKS
         # the stochastic input will be defined by last 2 scalar
         input_seq = rollouts.obs[1:, :, -2:]
@@ -104,7 +97,7 @@ def compute_contrastive_loss(self, rollouts, advantages):
         # encode the input sequence
         # Let's figure out which steps in the sequence have a zero for any agent
         # We will always assume t=0 has a zero in it as that makes the logic cleaner
-        has_zeros = ((masks[1:] == 0.0)
+        has_zeros = ((masks[:-1] == 0.0)
                      .any(dim=-1)
                      .nonzero()
                      .squeeze()
@@ -128,8 +121,13 @@ def compute_contrastive_loss(self, rollouts, advantages):
             start_idx = has_zeros[i]
             end_idx = has_zeros[i + 1]

-            output, _ = self.input_seq_encoder(
-                input_seq[start_idx + 1: end_idx + 1])
+            try:
+                output, _ = self.input_seq_encoder(
+                    input_seq[start_idx + 1: end_idx + 1])
+            except:  # note: if this fires, `output` is left undefined for the append below
+                print(start_idx)
+                print(end_idx)
+                print(has_zeros)

             outputs.append(output)

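The chunking in this hunk is the usual recurrent-rollout trick: instead of stepping the GRU one timestep at a time, the sequence is split at every step where some environment's mask is zero (an episode reset), and each zero-free chunk is encoded in a single call, so the hidden state implicitly restarts at episode boundaries. A minimal, self-contained restatement of the idea; the function and variable names here are illustrative, not the repo's API:

```python
import torch
import torch.nn as nn

def encode_by_segments(seq, masks, rnn):
    """Run `rnn` over seq (T, N, feat) in chunks delimited by episode resets.
    masks is (T, N) with 0.0 wherever an environment was reset, so the
    recurrent state must not carry across that step."""
    T = seq.size(0)
    # timesteps where any environment has a zero mask, i.e. an episode start
    boundaries = (masks == 0.0).any(dim=-1).nonzero().squeeze(-1).cpu().tolist()
    # bracket with t=0 and t=T so every step falls inside exactly one chunk
    boundaries = [0] + [b for b in boundaries if b > 0] + [T]

    outputs = []
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        if start == end:
            continue
        # one RNN call per zero-free chunk; the hidden state starts fresh here,
        # which is exactly the "reset at episode boundary" behaviour
        out, _ = rnn(seq[start:end])
        outputs.append(out)
    return torch.cat(outputs, dim=0)  # (T, N, hidden)

# usage sketch with made-up sizes
gru = nn.GRU(input_size=2, hidden_size=8)
seq = torch.randn(5, 3, 2)                    # T=5 steps, N=3 envs, 2 features
masks = torch.ones(5, 3)
masks[2, 1] = 0.0                             # env 1 resets at t=2
encoded = encode_by_segments(seq, masks, gru)  # -> (5, 3, 8)
```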
@@ -139,12 +137,21 @@ def compute_contrastive_loss(self, rollouts, advantages):
         # reverse back
         input_seq = torch.flip(input_seq, [0])

+        return input_seq
+
+    def _encode_advantages(self, advantages):
+        # FIXME: only compatible with 1D observation
+        num_steps, n_processes, _ = advantages.shape
         # ADVANTAGES
         # encode
         # n_steps x n_process x hidden_dim/2
         advantages = self.advantage_encoder(
             advantages.reshape(-1, 1)).reshape(num_steps, n_processes, -1)

+        return advantages
+
+    def _encode_states(self, rollouts):
+        num_steps, n_processes, _ = rollouts.actions.shape
         # STATES
         # encode
         # n_steps x n_process x hidden_dim/2
@@ -154,29 +161,52 @@ def compute_contrastive_loss(self, rollouts, advantages):
         states = self.state_encoder(
             states.reshape(-1, states_shape)).reshape(num_steps, n_processes, -1)

+        return states
+
+    def _encode_actions(self, rollouts):
+        num_steps, n_processes, _ = rollouts.actions.shape
         # ACTION
         # encode
         # n_steps x n_process x 1
         actions = rollouts.actions
         actions = self.action_encoder(
             actions.reshape(-1)).reshape(num_steps, n_processes, -1)

-        # condition = STATE + ADVANTAGE
-        conditions = torch.cat([advantages, states, actions], dim=-1)
+        return actions
+
+    def compute_contrastive_loss(self, rollouts, encoded_advantages):
+        """
+        Contrastive Predictive Coding for learning representations and the density ratio
+        :param rollouts: Storage instance
+        :param encoded_advantages: tensor of shape (timestep, n_processes, 1)
+        """
+        # FIXME: only compatible with 1D observation
+        num_steps, n_processes, _ = encoded_advantages.shape
+
+        # encode all the inputs
+        encoded_input_seq = self._encode_input_sequences(rollouts)
+        encoded_advantages = self._encode_advantages(encoded_advantages)
+        encoded_states = self._encode_states(rollouts)
+        encoded_actions = self._encode_actions(rollouts)
+
+        # condition = STATE + ADVANTAGE + ACTIONS
+        conditions = torch.cat(
+            [encoded_advantages, encoded_states, encoded_actions], dim=-1)
         # reshape to n_steps x hidden_dim x n_processes
         conditions = conditions.permute(0, 2, 1)

         # compute nce
         contrastive_loss = 0
         correct = 0
         for i in range(num_steps):
-            density_ratio = torch.mm(input_seq[i], conditions[i])
+            # f(Z, s0, a0, R) WITHOUT exponential
+            f_value = torch.mm(encoded_input_seq[i], conditions[i])
             # accuracy
             correct += torch.sum(torch.eq(torch.argmax(self.softmax(
-                density_ratio), dim=1), torch.arange(0, n_processes).to(self.device)))
+                f_value), dim=1), torch.arange(0, n_processes).to(self.device)))
             # nce
             contrastive_loss += torch.sum(
-                torch.diag(self.log_softmax(density_ratio)))
+                torch.diag(self.log_softmax(f_value)))

         # log loss
         contrastive_loss /= -1 * n_processes * num_steps
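For reference, the loop above is a per-timestep InfoNCE objective: at each step, row i of the N x N score matrix pairs process i's encoded input sequence with the (advantage, state, action) condition of every process; the diagonal entries are the positive pairs and the other columns in each row act as negatives. A condensed, self-contained restatement; the standalone shapes and the plain row-wise argmax for accuracy are slight simplifications of the code above:

```python
import torch
import torch.nn.functional as F

def info_nce(encoded_seq, conditions):
    """Condensed sketch of the per-step NCE loop above.
    encoded_seq: (T, N, H) GRU encodings of the input sequences.
    conditions:  (T, H, N) concatenated advantage/state/action encodings,
                 already permuted as in the code above."""
    T, N, _ = encoded_seq.shape
    loss = 0.0
    correct = 0
    for t in range(T):
        # scores[i, j] = <sequence encoding of process i, condition of process j>;
        # the diagonal holds the positive pairs, off-diagonal entries are negatives
        scores = encoded_seq[t] @ conditions[t]                     # (N, N)
        loss = loss + torch.diag(F.log_softmax(scores, dim=0)).sum()
        correct = correct + (scores.argmax(dim=1) == torch.arange(N)).sum()
    loss = loss / (-1 * N * T)        # negative mean log density ratio
    accuracy = correct / (N * T)      # fraction of steps where the positive pair wins
    return loss, accuracy

# usage sketch with made-up sizes
T, N, H = 4, 3, 16
loss, acc = info_nce(torch.randn(T, N, H), torch.randn(T, H, N))
```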
@@ -193,76 +223,14 @@ def compute_weighted_advantages(self, rollouts, advantages):
         # FIXME: only compatible with 1D observation
         num_steps, n_processes, _ = advantages.shape

-        # INPUT SEQUENCES AND MASKS
-        # the stochastic input will be defined by last 2 scalar
-        input_seq = rollouts.obs[1:, :, -2:]
-        masks = rollouts.masks[1:].reshape(num_steps, n_processes)
-        # reverse the input seq order since we want to compute from right to left
-        input_seq = torch.flip(input_seq, [0])
-        masks = torch.flip(masks, [0])
-        # encode the input sequence
-        # Let's figure out which steps in the sequence have a zero for any agent
-        # We will always assume t=0 has a zero in it as that makes the logic cleaner
-        has_zeros = ((masks[1:] == 0.0)
-                     .any(dim=-1)
-                     .nonzero()
-                     .squeeze()
-                     .cpu())
-
-        # +1 to correct the masks[1:]
-        if has_zeros.dim() == 0:
-            # Deal with scalar
-            has_zeros = [has_zeros.item() + 1]
-        else:
-            has_zeros = (has_zeros + 1).numpy().tolist()
-
-        # add t=0 and t=T to the list
-        has_zeros = [-1] + has_zeros + [num_steps - 1]
-
-        outputs = []
-
-        for i in range(len(has_zeros) - 1):
-            # We can now process steps that don't have any zeros in masks together!
-            # This is much faster
-            start_idx = has_zeros[i]
-            end_idx = has_zeros[i + 1]
-
-            output, _ = self.input_seq_encoder(
-                input_seq[start_idx + 1: end_idx + 1])
-
-            outputs.append(output)
-
-        # x is a (T, N, -1) tensor
-        input_seq = torch.cat(outputs, dim=0)
-        assert len(input_seq) == num_steps
-        # reverse back
-        input_seq = torch.flip(input_seq, [0])
-
-        # ADVANTAGES
-        # encode
-        # n_steps x n_process x hidden_dim/2
-        encoded_advantages = self.advantage_encoder(
-            advantages.reshape(-1, 1)).reshape(num_steps, n_processes, -1)
-
-        # STATES
-        # encode
-        # n_steps x n_process x hidden_dim/2
-        states = rollouts.obs[:-1]
-        # FIXME: hard code for 1D env
-        states_shape = states.shape[2:][0]
-        states = self.state_encoder(
-            states.reshape(-1, states_shape)).reshape(num_steps, n_processes, -1)
-
-        # ACTION
-        # encode
-        # n_steps x n_process x 1
-        actions = rollouts.actions
-        actions = self.action_encoder(
-            actions.reshape(-1)).reshape(num_steps, n_processes, -1)
+        input_seq = self._encode_input_sequences(rollouts)
+        encoded_advantages = self._encode_advantages(advantages)
+        encoded_states = self._encode_states(rollouts)
+        encoded_actions = self._encode_actions(rollouts)

         # condition = STATE + ADVANTAGE
         conditions = torch.cat(
-            [encoded_advantages, states, actions], dim=-1)
+            [encoded_advantages, encoded_states, encoded_actions], dim=-1)
         # reshape to n_steps x hidden_dim x n_processes
         conditions = conditions.permute(0, 2, 1)

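The encoder modules themselves (`advantage_encoder`, `state_encoder`, `action_encoder`, `input_seq_encoder`) are not part of this diff. The sketch below only illustrates the shape bookkeeping that makes `torch.mm(encoded_input_seq[i], conditions[i])` valid: the concatenated advantage/state/action encoding has to match the GRU hidden size. The layer types, sizes, and the split between the three encoders are assumptions, not the repo's definitions:

```python
import torch
import torch.nn as nn

hidden_dim, obs_dim, n_actions = 64, 6, 4   # assumed sizes, not from the diff

# hypothetical stand-ins for the encoders referenced above
advantage_encoder = nn.Linear(1, hidden_dim // 4)
state_encoder = nn.Linear(obs_dim, hidden_dim // 4)
action_encoder = nn.Embedding(n_actions, hidden_dim // 2)
input_seq_encoder = nn.GRU(input_size=2, hidden_size=hidden_dim)

T, N = 5, 3
adv = advantage_encoder(torch.randn(T * N, 1)).reshape(T, N, -1)
obs = state_encoder(torch.randn(T * N, obs_dim)).reshape(T, N, -1)
act = action_encoder(torch.randint(n_actions, (T * N,))).reshape(T, N, -1)

# condition width = hidden_dim//4 + hidden_dim//4 + hidden_dim//2 = hidden_dim,
# so each (N, hidden_dim) GRU output can be scored against it with torch.mm
conditions = torch.cat([adv, obs, act], dim=-1).permute(0, 2, 1)   # (T, hidden_dim, N)
seq_out, _ = input_seq_encoder(torch.randn(T, N, 2))               # (T, N, hidden_dim)
scores = torch.mm(seq_out[0], conditions[0])                       # (N, N) per-step scores
```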