Skip to content

Commit 4cbefba

Browse files
author
Louis Faury
committed
Composite specs can create named tensors with 'zero' and 'rand'
1 parent e7ec9c3 commit 4cbefba

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

test/test_specs.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4585,6 +4585,26 @@ def test_names_repr(self):
45854585
assert "Composite" in repr_str
45864586
assert "obs" in repr_str
45874587

4588+
def test_zero_create_names(self):
4589+
"""Test that creating tensors with 'zero' propagates names."""
4590+
spec = Composite(
4591+
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
4592+
shape=(10,),
4593+
names=["batch"],
4594+
)
4595+
td = spec.zero()
4596+
td.names = ["batch"]
4597+
4598+
def test_rand_create_names(self):
4599+
"""Test that creating tensors with 'rand' propagates names."""
4600+
spec = Composite(
4601+
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
4602+
shape=(10,),
4603+
names=["batch"],
4604+
)
4605+
td = spec.rand()
4606+
td.names = ["batch"]
4607+
45884608

45894609
if __name__ == "__main__":
45904610
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/data/tensor_specs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5750,6 +5750,7 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase:
57505750
_dict,
57515751
batch_size=_size([*shape, *_remove_neg_shapes(self.shape)]),
57525752
device=self.device,
5753+
names=self.names,
57535754
)
57545755

57555756
def keys(
@@ -6030,6 +6031,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
60306031
},
60316032
batch_size=_size([*shape, *self._safe_shape]),
60326033
device=device,
6034+
names=self.names,
60336035
)
60346036

60356037
def __eq__(self, other: object) -> bool:

0 commit comments

Comments
 (0)