Skip to content

Commit 293ab82

Browse files
committed
fixed has_zeros input in lacie + refactor code
1 parent 14daf9f commit 293ab82

File tree

1 file changed

+52
-84
lines changed

1 file changed

+52
-84
lines changed

core/algorithms/lacie/base_lacie.py

Lines changed: 52 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,8 @@ def __init__(self,
8585
self.softmax = nn.Softmax(dim=0)
8686
self.log_softmax = nn.LogSoftmax(dim=0)
8787

88-
def compute_contrastive_loss(self, rollouts, advantages):
89-
"""
90-
Contrastive Predictive Coding for learning representation and density ratio
91-
:param rollouts: Storage's instance
92-
:param advantage: tensor of shape: (timestep, n_processes, 1)
93-
"""
94-
# FIXME: only compatible with 1D observation
95-
num_steps, n_processes, _ = advantages.shape
96-
88+
def _encode_input_sequences(self, rollouts):
89+
num_steps, n_processes, _ = rollouts.actions.shape
9790
# INPUT SEQUENCES AND MASKS
9891
# the stochastic input will be defined by last 2 scalar
9992
input_seq = rollouts.obs[1:, :, -2:]
@@ -104,7 +97,7 @@ def compute_contrastive_loss(self, rollouts, advantages):
10497
# encode the input sequence
10598
# Let's figure out which steps in the sequence have a zero for any agent
10699
# We will always assume t=0 has a zero in it as that makes the logic cleaner
107-
has_zeros = ((masks[1:] == 0.0)
100+
has_zeros = ((masks[:-1] == 0.0)
108101
.any(dim=-1)
109102
.nonzero()
110103
.squeeze()
@@ -128,8 +121,13 @@ def compute_contrastive_loss(self, rollouts, advantages):
128121
start_idx = has_zeros[i]
129122
end_idx = has_zeros[i + 1]
130123

131-
output, _ = self.input_seq_encoder(
132-
input_seq[start_idx + 1: end_idx + 1])
124+
try:
125+
output, _ = self.input_seq_encoder(
126+
input_seq[start_idx + 1: end_idx + 1])
127+
except:
128+
print(start_idx)
129+
print(end_idx)
130+
print(has_zeros)
133131

134132
outputs.append(output)
135133

@@ -139,12 +137,21 @@ def compute_contrastive_loss(self, rollouts, advantages):
139137
# reverse back
140138
input_seq = torch.flip(input_seq, [0])
141139

140+
return input_seq
141+
142+
def _encode_advantages(self, advantages):
143+
# FIXME: only compatible with 1D observation
144+
num_steps, n_processes, _ = advantages.shape
142145
# ADVANTAGES
143146
# encode
144147
# n_steps x n_process x hidden_dim/2
145148
advantages = self.advantage_encoder(
146149
advantages.reshape(-1, 1)).reshape(num_steps, n_processes, -1)
147150

151+
return advantages
152+
153+
def _encode_states(self, rollouts):
154+
num_steps, n_processes, _ = rollouts.actions.shape
148155
# STATES
149156
# encode
150157
# n_steps x n_process x hidden_dim/2
@@ -154,29 +161,52 @@ def compute_contrastive_loss(self, rollouts, advantages):
154161
states = self.state_encoder(
155162
states.reshape(-1, states_shape)).reshape(num_steps, n_processes, -1)
156163

164+
return states
165+
166+
def _encode_actions(self, rollouts):
167+
num_steps, n_processes, _ = rollouts.actions.shape
157168
# ACTION
158169
# encode
159170
# n_steps x n_process x 1
160171
actions = rollouts.actions
161172
actions = self.action_encoder(
162173
actions.reshape(-1)).reshape(num_steps, n_processes, -1)
163174

164-
# condition = STATE + ADVANTAGE
165-
conditions = torch.cat([advantages, states, actions], dim=-1)
175+
return actions
176+
177+
def compute_contrastive_loss(self, rollouts, encoded_advantages):
178+
"""
179+
Contrastive Predictive Coding for learning representation and density ratio
180+
:param rollouts: Storage's instance
181+
:param encoded_advantages: tensor of shape: (timestep, n_processes, 1)
182+
"""
183+
# FIXME: only compatible with 1D observation
184+
num_steps, n_processes, _ = encoded_advantages.shape
185+
186+
# encode all the inputs
187+
encoded_input_seq = self._encode_input_sequences(rollouts)
188+
encoded_advantages = self._encode_advantages(encoded_advantages)
189+
encoded_states = self._encode_states(rollouts)
190+
encoded_actions = self._encode_actions(rollouts)
191+
192+
# condition = STATE + ADVANTAGE + ACTIONS
193+
conditions = torch.cat(
194+
[encoded_advantages, encoded_states, encoded_actions], dim=-1)
166195
# reshape to n_steps x hidden_dim x n_processes
167196
conditions = conditions.permute(0, 2, 1)
168197

169198
# compute nce
170199
contrastive_loss = 0
171200
correct = 0
172201
for i in range(num_steps):
173-
density_ratio = torch.mm(input_seq[i], conditions[i])
202+
# f(Z, s0, a0, R) WITHOUT exponential
203+
f_value = torch.mm(encoded_input_seq[i], conditions[i])
174204
# accuracy
175205
correct += torch.sum(torch.eq(torch.argmax(self.softmax(
176-
density_ratio), dim=1), torch.arange(0, n_processes).to(self.device)))
206+
f_value), dim=1), torch.arange(0, n_processes).to(self.device)))
177207
# nce
178208
contrastive_loss += torch.sum(
179-
torch.diag(self.log_softmax(density_ratio)))
209+
torch.diag(self.log_softmax(f_value)))
180210

181211
# log loss
182212
contrastive_loss /= -1*n_processes*num_steps
@@ -193,76 +223,14 @@ def compute_weighted_advantages(self, rollouts, advantages):
193223
# FIXME: only compatible with 1D observation
194224
num_steps, n_processes, _ = advantages.shape
195225

196-
# INPUT SEQUENCES AND MASKS
197-
# the stochastic input will be defined by last 2 scalar
198-
input_seq = rollouts.obs[1:, :, -2:]
199-
masks = rollouts.masks[1:].reshape(num_steps, n_processes)
200-
# reverse the input seq order since we want to compute from right to left
201-
input_seq = torch.flip(input_seq, [0])
202-
masks = torch.flip(masks, [0])
203-
# encode the input sequence
204-
# Let's figure out which steps in the sequence have a zero for any agent
205-
# We will always assume t=0 has a zero in it as that makes the logic cleaner
206-
has_zeros = ((masks[1:] == 0.0)
207-
.any(dim=-1)
208-
.nonzero()
209-
.squeeze()
210-
.cpu())
211-
212-
# +1 to correct the masks[1:]
213-
if has_zeros.dim() == 0:
214-
# Deal with scalar
215-
has_zeros = [has_zeros.item() + 1]
216-
else:
217-
has_zeros = (has_zeros + 1).numpy().tolist()
218-
219-
# add t=0 and t=T to the list
220-
has_zeros = [-1] + has_zeros + [num_steps - 1]
221-
222-
outputs = []
223-
224-
for i in range(len(has_zeros) - 1):
225-
# We can now process steps that don't have any zeros in masks together!
226-
# This is much faster
227-
start_idx = has_zeros[i]
228-
end_idx = has_zeros[i + 1]
229-
230-
output, _ = self.input_seq_encoder(
231-
input_seq[start_idx + 1: end_idx + 1])
232-
233-
outputs.append(output)
234-
235-
# x is a (T, N, -1) tensor
236-
input_seq = torch.cat(outputs, dim=0)
237-
assert len(input_seq) == num_steps
238-
# reverse back
239-
input_seq = torch.flip(input_seq, [0])
240-
241-
# ADVANTAGES
242-
# encode
243-
# n_steps x n_process x hidden_dim/2
244-
encoded_advantages = self.advantage_encoder(
245-
advantages.reshape(-1, 1)).reshape(num_steps, n_processes, -1)
246-
247-
# STATES
248-
# encode
249-
# n_steps x n_process x hidden_dim/2
250-
states = rollouts.obs[:-1]
251-
# FIXME: hard code for 1D env
252-
states_shape = states.shape[2:][0]
253-
states = self.state_encoder(
254-
states.reshape(-1, states_shape)).reshape(num_steps, n_processes, -1)
255-
256-
# ACTION
257-
# encode
258-
# n_steps x n_process x 1
259-
actions = rollouts.actions
260-
actions = self.action_encoder(
261-
actions.reshape(-1)).reshape(num_steps, n_processes, -1)
226+
input_seq = self._encode_input_sequences(rollouts)
227+
encoded_advantages = self._encode_advantages(advantages)
228+
encoded_states = self._encode_states(rollouts)
229+
encoded_actions = self._encode_actions(rollouts)
262230

263231
# condition = STATE + ADVANTAGE
264232
conditions = torch.cat(
265-
[encoded_advantages, states, actions], dim=-1)
233+
[encoded_advantages, encoded_states, encoded_actions], dim=-1)
266234
# reshape to n_steps x hidden_dim x n_processes
267235
conditions = conditions.permute(0, 2, 1)
268236

0 commit comments

Comments
 (0)