180 changes: 180 additions & 0 deletions test/test_collector.py
@@ -1712,6 +1712,186 @@ def env_fn():
total_frames=frames_per_batch * 100,
)

class FixedIDEnv(EnvBase):
"""
A simple mock environment that returns a fixed ID as its sole observation.

This environment is designed to test MultiSyncDataCollector ordering.
Each environment instance is initialized with a unique env_id, which it
returns as the observation at every step.
"""

def __init__(
self,
env_id: int,
max_steps: int = 10,
sleep_odd_only: bool = False,
**kwargs,
):
"""
Args:
env_id: The ID to return as observation. This will be returned as a tensor.
max_steps: Maximum number of steps before the environment terminates.
"""
super().__init__(device="cpu", batch_size=torch.Size([]))
self.env_id = env_id
self.max_steps = max_steps
self.sleep_odd_only = sleep_odd_only
self._step_count = 0

# Define specs
self.observation_spec = Composite(
observation=Unbounded(shape=(1,), dtype=torch.float32)
)
self.action_spec = Composite(
action=Unbounded(shape=(1,), dtype=torch.float32)
)
self.reward_spec = Composite(
reward=Unbounded(shape=(1,), dtype=torch.float32)
)
self.done_spec = Composite(
done=Unbounded(shape=(1,), dtype=torch.bool),
terminated=Unbounded(shape=(1,), dtype=torch.bool),
truncated=Unbounded(shape=(1,), dtype=torch.bool),
)

def _reset(self, tensordict: TensorDict | None = None, **kwargs) -> TensorDict:
"""Reset the environment and return initial observation."""
# Add sleep to simulate real-world timing variations
# This helps test that the collector properly handles different reset times
if not self.sleep_odd_only:
# Random sleep up to 10ms
time.sleep(torch.rand(1).item() * 0.01)
elif self.env_id % 2 == 1:
time.sleep(1)

self._step_count = 0
return TensorDict(
{
"observation": torch.tensor(
[float(self.env_id)], dtype=torch.float32
),
"done": torch.tensor([False], dtype=torch.bool),
"terminated": torch.tensor([False], dtype=torch.bool),
"truncated": torch.tensor([False], dtype=torch.bool),
},
batch_size=self.batch_size,
)

def _step(self, tensordict: TensorDict) -> TensorDict:
"""Execute one step and return the env_id as observation."""
self._step_count += 1
done = self._step_count >= self.max_steps

if self.sleep_odd_only and self.env_id % 2 == 1:
time.sleep(1)

return TensorDict(
{
"observation": torch.tensor(
[float(self.env_id)], dtype=torch.float32
),
"reward": torch.tensor([1.0], dtype=torch.float32),
"done": torch.tensor([done], dtype=torch.bool),
"terminated": torch.tensor([done], dtype=torch.bool),
"truncated": torch.tensor([False], dtype=torch.bool),
},
batch_size=self.batch_size,
)

def _set_seed(self, seed: int | None) -> int | None:
"""Set the seed for reproducibility."""
if seed is not None:
torch.manual_seed(seed)
return seed

@pytest.mark.parametrize("num_envs,n_steps", [(8, 5)])
@pytest.mark.parametrize("with_preempt", [False, True])
@pytest.mark.parametrize("cat_results", ["stack", -1])
Comment on lines +1808 to +1810

@LCarmi (Contributor, Author) commented on Dec 16, 2025:
I have made the test suite much wider, covering both the preemption/no-preemption cases and the stack/cat result modes. On my machine, the preemption tests last ~10s while the non-preemption ones last ~4s. I am open to trying to make them shorter (by tuning the time.sleep calls) if you deem it worthwhile.
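For illustration, a minimal sketch of how the sleeps could be made tunable; the constant names below are hypothetical and not part of this PR:

```python
# Hypothetical timing knobs for FixedIDEnv (illustrative names, not in the PR).
FAST_RESET_SLEEP_MAX = 0.01  # seconds; upper bound of the random reset sleep
SLOW_ENV_SLEEP = 0.25        # seconds; would replace the hard-coded time.sleep(1)

# Inside FixedIDEnv._reset / _step, the sleeps would then read:
#     time.sleep(torch.rand(1).item() * FAST_RESET_SLEEP_MAX)
#     time.sleep(SLOW_ENV_SLEEP)
```

Lowering `SLOW_ENV_SLEEP` would be the main lever for shortening the preemption tests, as long as it stays clearly above the fast envs' reset jitter.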

def test_multi_sync_data_collector_ordering(
self, num_envs: int, n_steps: int, with_preempt: bool, cat_results: str | int
):
"""
Test that MultiSyncDataCollector returns data in the correct order.

We create num_envs environments, each returning its env_id as the observation.
        After collection, we verify that the observations correspond to the correct env_ids, in order.
"""
if with_preempt and IS_OSX:
pytest.skip(
"Cannot use preemption on OSX due to Queue.qsize() not being implemented on this platform."
)

# Create environment factories using partial - one for each env_id
# This pattern mirrors CrossPlayEvaluator._rollout usage
env_factories = [
functools.partial(
self.FixedIDEnv, env_id=i, max_steps=10, sleep_odd_only=with_preempt
)
for i in range(num_envs)
]

# Initialize MultiSyncDataCollector
collector = MultiSyncDataCollector(
create_env_fn=env_factories,
frames_per_batch=num_envs * n_steps,
total_frames=num_envs * n_steps,
device="cpu",
preemptive_threshold=0.5 if with_preempt else None,
cat_results=cat_results,
            init_random_frames=n_steps,  # no need for a policy
use_buffers=True,
)

# Collect one batch
for batch in collector:
            # Verify that each environment's observations match its env_id.
            # batch has shape [num_envs, frames_per_env].
            # In the preemption case, envs with odd ids are an order of magnitude
            # slower, so they should be skipped by preemption (they are the
            # slowest 50%).

            # Recover the rectangular shape of the batch to make the checks uniform
if cat_results != "stack":
if not with_preempt:
batch = batch.reshape(num_envs, n_steps)
else:
traj_ids = batch["collector", "traj_ids"]
traj_ids[traj_ids == 0] = 99 # avoid using traj_ids = 0
                    # Split trajectories to recover the correct shape, relying on
                    # each env producing a single trajectory.
                    # Note: split_trajectories pads with zeros!
batch = split_trajectories(
batch, trajectory_key=("collector", "traj_ids")
)
                    # Use -1 for padding, consistent with the stack-based preemption path
is_padded = batch["collector", "traj_ids"] == 0
batch[is_padded] = -1

            # Check each environment's rows for correct ordering and preemption
for env_idx in range(num_envs):
if with_preempt and env_idx % 2 == 1:
                    # This is a slow env; it should have been preempted after the first step
assert (batch["collector", "traj_ids"][env_idx, 1:] == -1).all()
continue
                # This is a fast env; no preemption happened
assert (batch["collector", "traj_ids"][env_idx] != -1).all()

env_data = batch[env_idx]
observations = env_data["observation"]
# All observations from this environment should equal its env_id
expected_id = float(env_idx)
actual_ids = observations.flatten().unique()

assert len(actual_ids) == 1, (
f"Env {env_idx} should only produce observations with value {expected_id}, "
f"but got {actual_ids.tolist()}"
)
assert (
actual_ids[0].item() == expected_id
), f"Environment {env_idx} should produce observation {expected_id}, but got {actual_ids[0].item()}"

collector.shutdown()


class TestCollectorDevices:
class DeviceLessEnv(EnvBase):
39 changes: 23 additions & 16 deletions torchrl/collectors/_multi_sync.py
@@ -216,7 +216,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
if cat_results is None:
cat_results = "stack"

self.buffers = {}
self.buffers = [None for _ in range(self.num_workers)]
dones = [False for _ in range(self.num_workers)]
workers_frames = [0 for _ in range(self.num_workers)]
same_device = None
@@ -236,7 +236,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
msg = "continue_random"
else:
msg = "continue"
# Debug: sending 'continue'
self.pipes[idx].send((None, msg))

self._iter += 1
@@ -299,16 +298,19 @@ def iterator(self) -> Iterator[TensorDictBase]:
if preempt:
# mask buffers if cat, and create a mask if stack
if cat_results != "stack":
buffers = {}
for worker_idx, buffer in self.buffers.items():
buffers = [None] * self.num_workers
for worker_idx, buffer in enumerate(self.buffers):
# Skip pre-empted envs:
if buffer is None:
continue
valid = buffer.get(("collector", "traj_ids")) != -1
if valid.ndim > 2:
valid = valid.flatten(0, -2)
if valid.ndim == 2:
valid = valid.any(0)
buffers[worker_idx] = buffer[..., valid]
else:
for buffer in self.buffers.values():
for buffer in filter(lambda x: x is not None, self.buffers):
with buffer.unlock_():
buffer.set(
("collector", "mask"),
@@ -320,7 +322,7 @@

# Skip frame counting if this worker didn't send data this iteration
# (happens when reusing buffers or on first iteration with some workers)
if idx not in buffers:
if self.buffers[idx] is None:
continue

workers_frames[idx] = workers_frames[idx] + buffers[idx].numel()
@@ -331,18 +333,18 @@
if self.replay_buffer is not None:
yield
self._frames += sum(
[
self.frames_per_batch_worker(worker_idx=worker_idx)
for worker_idx in range(self.num_workers)
]
self.frames_per_batch_worker(worker_idx)
for worker_idx in range(self.num_workers)
)
continue

# we have to correct the traj_ids to make sure that they don't overlap
# We can count the number of frames collected for free in this loop
n_collected = 0
for idx in buffers.keys():
for idx in range(self.num_workers):
buffer = buffers[idx]
if buffer is None:
continue
traj_ids = buffer.get(("collector", "traj_ids"))
if preempt:
if cat_results == "stack":
@@ -356,7 +358,7 @@
if same_device is None:
prev_device = None
same_device = True
for item in self.buffers.values():
for item in filter(lambda x: x is not None, self.buffers):
if prev_device is None:
prev_device = item.device
else:
@@ -367,10 +369,12 @@
torch.stack if self._use_buffers else TensorDict.maybe_dense_stack
)
if same_device:
self.out_buffer = stack(list(buffers.values()), 0)
self.out_buffer = stack(
[item for item in buffers if item is not None], 0
)
else:
self.out_buffer = stack(
[item.cpu() for item in buffers.values()], 0
[item.cpu() for item in buffers if item is not None], 0
)
else:
if self._use_buffers is None:
@@ -383,10 +387,13 @@
)
try:
if same_device:
self.out_buffer = torch.cat(list(buffers.values()), cat_results)
self.out_buffer = torch.cat(
[item for item in buffers if item is not None], cat_results
)
else:
self.out_buffer = torch.cat(
[item.cpu() for item in buffers.values()], cat_results
[item.cpu() for item in buffers if item is not None],
cat_results,
)
except RuntimeError as err:
if (
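As background on the refactor above, here is a minimal self-contained sketch (toy data, not TorchRL code) of the pattern this diff moves to: `self.buffers` becomes a fixed-size list indexed by worker id, with `None` marking workers that sent nothing, so worker ordering survives even when slow workers are preempted:

```python
# Toy sketch (not TorchRL code) of the worker-indexed buffer pattern above.
num_workers = 4
buffers = [None] * num_workers  # position == worker id; None == no data yet

# Suppose only workers 0 and 2 delivered data this iteration:
buffers[0], buffers[2] = "batch-from-0", "batch-from-2"

# Downstream code skips absent entries without losing positional identity:
delivered = [(idx, b) for idx, b in enumerate(buffers) if b is not None]
assert delivered == [(0, "batch-from-0"), (2, "batch-from-2")]
```

Stacking the values of a dict populated in arrival order would instead reflect completion order, which appears to be the ordering issue the new test pins down.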