@@ -94,14 +94,48 @@ def compute_contrastive_loss(self, rollouts, advantages):
9494 # FIXME: only compatible with 1D observation
9595 num_steps , n_processes , _ = advantages .shape
9696
97- # INPUT SEQUENCES
97+ # INPUT SEQUENCES AND MASKS
9898 # the stochastic input will be defined by last 2 scalar
9999 input_seq = rollouts .obs [1 :, :, - 2 :]
100+ masks = rollouts .masks [1 :].reshape (num_steps , n_processes )
100101 # reverse the input seq order since we want to compute from right to left
101102 input_seq = torch .flip (input_seq , [0 ])
103+ masks = torch .flip (masks , [0 ])
102104 # encode the input sequence
103- # n_steps x n_processes x hidden_dim
104- input_seq , _ = self .input_seq_encoder (input_seq )
105+ # Let's figure out which steps in the sequence have a zero for any agent
106+ # We will always assume t=0 has a zero in it as that makes the logic cleaner
107+ has_zeros = ((masks [1 :] == 0.0 )
108+ .any (dim = - 1 )
109+ .nonzero ()
110+ .squeeze ()
111+ .cpu ())
112+
113+ # +1 to correct the masks[1:]
114+ if has_zeros .dim () == 0 :
115+ # Deal with scalar
116+ has_zeros = [has_zeros .item () + 1 ]
117+ else :
118+ has_zeros = (has_zeros + 1 ).numpy ().tolist ()
119+
120+ # add t=0 and t=T to the list
121+ has_zeros = [- 1 ] + has_zeros + [num_steps - 1 ]
122+
123+ outputs = []
124+
125+ for i in range (len (has_zeros ) - 1 ):
126+ # We can now process steps that don't have any zeros in masks together!
127+ # This is much faster
128+ start_idx = has_zeros [i ]
129+ end_idx = has_zeros [i + 1 ]
130+
131+ output , _ = self .input_seq_encoder (
132+ input_seq [start_idx + 1 : end_idx + 1 ])
133+
134+ outputs .append (output )
135+
136+ # x is a (T, N, -1) tensor
137+ input_seq = torch .cat (outputs , dim = 0 )
138+ assert len (input_seq ) == num_steps
105139 # reverse back
106140 input_seq = torch .flip (input_seq , [0 ])
107141
@@ -159,14 +193,48 @@ def compute_weighted_advantages(self, rollouts, advantages):
159193 # FIXME: only compatible with 1D observation
160194 num_steps , n_processes , _ = advantages .shape
161195
162- # INPUT SEQUENCES
196+ # INPUT SEQUENCES AND MASKS
163197 # the stochastic input will be defined by last 2 scalar
164198 input_seq = rollouts .obs [1 :, :, - 2 :]
199+ masks = rollouts .masks [1 :].reshape (num_steps , n_processes )
165200 # reverse the input seq order since we want to compute from right to left
166201 input_seq = torch .flip (input_seq , [0 ])
202+ masks = torch .flip (masks , [0 ])
167203 # encode the input sequence
168- # output shape: n_steps x n_processes x hidden_dim
169- input_seq , _ = self .input_seq_encoder (input_seq )
204+ # Let's figure out which steps in the sequence have a zero for any agent
205+ # We will always assume t=0 has a zero in it as that makes the logic cleaner
206+ has_zeros = ((masks [1 :] == 0.0 )
207+ .any (dim = - 1 )
208+ .nonzero ()
209+ .squeeze ()
210+ .cpu ())
211+
212+ # +1 to correct the masks[1:]
213+ if has_zeros .dim () == 0 :
214+ # Deal with scalar
215+ has_zeros = [has_zeros .item () + 1 ]
216+ else :
217+ has_zeros = (has_zeros + 1 ).numpy ().tolist ()
218+
219+ # add t=0 and t=T to the list
220+ has_zeros = [- 1 ] + has_zeros + [num_steps - 1 ]
221+
222+ outputs = []
223+
224+ for i in range (len (has_zeros ) - 1 ):
225+ # We can now process steps that don't have any zeros in masks together!
226+ # This is much faster
227+ start_idx = has_zeros [i ]
228+ end_idx = has_zeros [i + 1 ]
229+
230+ output , _ = self .input_seq_encoder (
231+ input_seq [start_idx + 1 : end_idx + 1 ])
232+
233+ outputs .append (output )
234+
235+ # x is a (T, N, -1) tensor
236+ input_seq = torch .cat (outputs , dim = 0 )
237+ assert len (input_seq ) == num_steps
170238 # reverse back
171239 input_seq = torch .flip (input_seq , [0 ])
172240
0 commit comments