Skip to content

Commit 5c2d8a8

Browse files
committed
Update
[ghstack-poisoned]
1 parent 952b7e7 commit 5c2d8a8

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

test/test_collector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1512,6 +1512,7 @@ def create_env():
15121512
cudagraph_policy=cudagraph,
15131513
weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
15141514
)
1515+
assert "policy" in collector._weight_senders, collector._weight_senders.keys()
15151516
try:
15161517
# collect state_dict
15171518
state_dict = collector.state_dict()

torchrl/collectors/collectors.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,19 @@ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any:
307307
else None
308308
)
309309

310+
# If no weights were provided and a sync scheme exists, extract the latest
311+
# weights from the current model using the scheme strategy (state_dict or tensordict).
312+
# This ensures we don't return stale cached weights.
313+
if weights is None and scheme is not None:
314+
from torchrl.weight_update.weight_sync_schemes import (
315+
_resolve_model,
316+
WeightStrategy,
317+
)
318+
319+
strategy = WeightStrategy(extract_as=scheme.strategy)
320+
model = _resolve_model(self, model_id)
321+
return strategy.extract_weights(model)
322+
310323
if weights is None:
311324
if model_id == "policy" and hasattr(self, "policy_weights"):
312325
return self.policy_weights

torchrl/envs/batched_envs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2492,7 +2492,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
24922492
# Set event before sending non-tensor data so parent knows worker is done
24932493
# The recv() call itself will provide synchronization for the pipe
24942494
mp_event.set()
2495-
2495+
24962496
if _non_tensor_keys:
24972497
child_pipe.send(
24982498
("non_tensor", next_td.select(*_non_tensor_keys, strict=False))
@@ -2534,7 +2534,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
25342534
# Set event before sending non-tensor data so parent knows worker is done
25352535
# The recv() call itself will provide synchronization for the pipe
25362536
mp_event.set()
2537-
2537+
25382538
if _non_tensor_keys:
25392539
ntd = root_next_td.select(*_non_tensor_keys)
25402540
ntd.set("next", td_next.select(*_non_tensor_keys))

0 commit comments

Comments
 (0)