
[BUG] ParallelEnv + aSyncDataCollector / MultiSyncDataCollector not working if replay_buffer is given #3240

@MathieuFonsProjects

Description


Describe the bug

If a replay_buffer is passed to the collector (so that .start() can be used) together with a ParallelEnv instance, it doesn't work (is flattening needed?): the collector freezes on the first collected batch, during the .extend call. Extending the buffer manually outside of the collector (without passing the replay buffer to it) works fine.

To Reproduce

Steps to reproduce the behavior.

Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.

Please use the markdown code blocks for both code and stack traces.

```python
import torch, time
import torch.nn as nn
from gymnasium.envs.classic_control.pendulum import PendulumEnv
from torchrl.envs import EnvCreator, ParallelEnv, GymWrapper, Transform, TransformedEnv, Compose, DTypeCastTransform, RewardScaling, RewardSum, StepCounter
from torchrl.collectors import aSyncDataCollector, MultiSyncDataCollector
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage, RandomSampler
from tensordict.nn import TensorDictModule

# Environment factory
def create_env(render = None):
    env = PendulumEnv("human")

    env = GymWrapper(env)

    env = TransformedEnv(
        env,
        Compose(DTypeCastTransform(torch.float64, torch.float32), 
                RewardScaling(-8., 8., "reward", "reward", True),
                RewardSum(in_keys=[("reward",)], out_keys=[("episode_reward",)]),
                StepCounter(256)
        )
    )

    return env

create_env_new = EnvCreator(create_env)

def parallel_env():
    return ParallelEnv(4, create_env_new)

if __name__ == '__main__':

    policy_net = nn.Linear(3, 1)
    policy = TensorDictModule(policy_net, in_keys=["observation"], out_keys=["action"]).to("cuda")

    replay_buffer = TensorDictReplayBuffer(
        storage=LazyTensorStorage(2e6, ndim=1),  # tried with ndim=1, ndim=2, and ndim=3
        sampler=RandomSampler(),
        batch_size=128,
    )

    # Create async data collector
    collector = aSyncDataCollector(  # try aSyncDataCollector or MultiSyncDataCollector
        parallel_env,  # works with the create_env function but not with this one
        policy,
        num_workers=1,
        frames_per_batch=64,
        total_frames=-1,
        extend_buffer=True,
        replay_buffer=replay_buffer,
        device=torch.device("cpu"),
        storing_device=torch.device("cpu"),
        env_device=torch.device("cpu"),
        policy_device=torch.device("cuda"),
    )

    # Doesn't work

    collector.start()

    while True:
        print(len(replay_buffer))
        time.sleep(2.)

    # Doesn't work if replay_buffer is given to the collector:
    # for batch in collector:
    #     print(batch.shape)
    #     replay_buffer.extend(batch)  # this works if no replay buffer is given to the collector and extend is called here instead
    #     print(len(replay_buffer))
    #     time.sleep(2.)
```

```
RuntimeError: indexed destination TensorDict batch size is torch.Size([4, 4]) (batch_size = torch.Size([2000000, 4]), index=tensor([0, 1, 2, 3])), which differs from the source batch size torch.Size([4, 16])
```

or

```
RuntimeError: expand_as_right requires the destination tensor to have less dimensions than the input tensor, got tensor.ndimension()=2 and dest.ndimension()=1
```

Expected behavior

Same behavior as when the replay buffer isn't given to the collector and is extended manually.
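
For reference, a minimal sketch of the variant that does work according to the description above (it is just the commented-out loop from the repro script, with the replay_buffer and extend_buffer arguments removed from the collector):

```python
# Works: same collector as above, but without replay_buffer=... / extend_buffer=True;
# the buffer is extended manually in the consuming loop instead.
collector = aSyncDataCollector(
    parallel_env,
    policy,
    num_workers=1,
    frames_per_batch=64,
    total_frames=-1,
    device=torch.device("cpu"),
    storing_device=torch.device("cpu"),
    env_device=torch.device("cpu"),
    policy_device=torch.device("cuda"),
)

for batch in collector:
    replay_buffer.extend(batch)
    print(len(replay_buffer))
    time.sleep(2.)
```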

Additional context

Found this in the docs on single-node data collectors:
Using replay buffers that sample trajectories with MultiSyncDataCollector isn’t currently fully supported as the data batches can come from any worker and in most cases consecutive batches written in the buffer won’t come from the same source (thereby interrupting the trajectories).

But I guess this applies only to MultiaSyncDataCollector and there is a typo in that sentence? It also doesn't explain why the Sync and aSync collectors (without Multi) wouldn't work. I may be wrong and have misunderstood.
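
To make the "flattening is needed?" guess from the description concrete, here is a hypothetical, untested sketch (not a confirmed fix): collapse the [num_envs, frames] batch produced by the ParallelEnv into a single leading batch dimension before extending manually, which is what the shape mismatch in the first error above seems to point at.

```python
# Hypothetical and untested: flatten the ParallelEnv batch (e.g. batch_size [4, 16])
# into a single leading dimension so it matches a storage created with ndim=1.
for batch in collector:
    replay_buffer.extend(batch.reshape(-1))
    print(len(replay_buffer))
```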

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)
