Skip to content

Commit 14daf9f

Browse files
committed
Fixed: reset hxs (hidden) state when an episode is finished
1 parent d5ea6e2 commit 14daf9f

File tree

2 files changed

+74
-224
lines changed

2 files changed

+74
-224
lines changed

core/algorithms/lacie/base_lacie.py

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,48 @@ def compute_contrastive_loss(self, rollouts, advantages):
9494
# FIXME: only compatible with 1D observation
9595
num_steps, n_processes, _ = advantages.shape
9696

97-
# INPUT SEQUENCES
97+
# INPUT SEQUENCES AND MASKS
9898
# the stochastic input will be defined by last 2 scalar
9999
input_seq = rollouts.obs[1:, :, -2:]
100+
masks = rollouts.masks[1:].reshape(num_steps, n_processes)
100101
# reverse the input seq order since we want to compute from right to left
101102
input_seq = torch.flip(input_seq, [0])
103+
masks = torch.flip(masks, [0])
102104
# encode the input sequence
103-
# n_steps x n_processes x hidden_dim
104-
input_seq, _ = self.input_seq_encoder(input_seq)
105+
# Let's figure out which steps in the sequence have a zero for any agent
106+
# We will always assume t=0 has a zero in it as that makes the logic cleaner
107+
has_zeros = ((masks[1:] == 0.0)
108+
.any(dim=-1)
109+
.nonzero()
110+
.squeeze()
111+
.cpu())
112+
113+
# +1 to correct the masks[1:]
114+
if has_zeros.dim() == 0:
115+
# Deal with scalar
116+
has_zeros = [has_zeros.item() + 1]
117+
else:
118+
has_zeros = (has_zeros + 1).numpy().tolist()
119+
120+
# add t=0 and t=T to the list
121+
has_zeros = [-1] + has_zeros + [num_steps - 1]
122+
123+
outputs = []
124+
125+
for i in range(len(has_zeros) - 1):
126+
# We can now process steps that don't have any zeros in masks together!
127+
# This is much faster
128+
start_idx = has_zeros[i]
129+
end_idx = has_zeros[i + 1]
130+
131+
output, _ = self.input_seq_encoder(
132+
input_seq[start_idx + 1: end_idx + 1])
133+
134+
outputs.append(output)
135+
136+
# x is a (T, N, -1) tensor
137+
input_seq = torch.cat(outputs, dim=0)
138+
assert len(input_seq) == num_steps
105139
# reverse back
106140
input_seq = torch.flip(input_seq, [0])
107141

@@ -159,14 +193,48 @@ def compute_weighted_advantages(self, rollouts, advantages):
159193
# FIXME: only compatible with 1D observation
160194
num_steps, n_processes, _ = advantages.shape
161195

162-
# INPUT SEQUENCES
196+
# INPUT SEQUENCES AND MASKS
163197
# the stochastic input will be defined by last 2 scalar
164198
input_seq = rollouts.obs[1:, :, -2:]
199+
masks = rollouts.masks[1:].reshape(num_steps, n_processes)
165200
# reverse the input seq order since we want to compute from right to left
166201
input_seq = torch.flip(input_seq, [0])
202+
masks = torch.flip(masks, [0])
167203
# encode the input sequence
168-
# output shape: n_steps x n_processes x hidden_dim
169-
input_seq, _ = self.input_seq_encoder(input_seq)
204+
# Let's figure out which steps in the sequence have a zero for any agent
205+
# We will always assume t=0 has a zero in it as that makes the logic cleaner
206+
has_zeros = ((masks[1:] == 0.0)
207+
.any(dim=-1)
208+
.nonzero()
209+
.squeeze()
210+
.cpu())
211+
212+
# +1 to correct the masks[1:]
213+
if has_zeros.dim() == 0:
214+
# Deal with scalar
215+
has_zeros = [has_zeros.item() + 1]
216+
else:
217+
has_zeros = (has_zeros + 1).numpy().tolist()
218+
219+
# add t=0 and t=T to the list
220+
has_zeros = [-1] + has_zeros + [num_steps - 1]
221+
222+
outputs = []
223+
224+
for i in range(len(has_zeros) - 1):
225+
# We can now process steps that don't have any zeros in masks together!
226+
# This is much faster
227+
start_idx = has_zeros[i]
228+
end_idx = has_zeros[i + 1]
229+
230+
output, _ = self.input_seq_encoder(
231+
input_seq[start_idx + 1: end_idx + 1])
232+
233+
outputs.append(output)
234+
235+
# x is a (T, N, -1) tensor
236+
input_seq = torch.cat(outputs, dim=0)
237+
assert len(input_seq) == num_steps
170238
# reverse back
171239
input_seq = torch.flip(input_seq, [0])
172240

train_imitation_learning.py

Lines changed: 0 additions & 218 deletions
This file was deleted.

0 commit comments

Comments (0)