
Commit 2b98c27

Yanghan Wang authored and facebook-github-bot committed
make the offload function customizable for DatasetFromList
Summary:
Pull Request resolved: #4626

Previously we used `serialize: bool` to control whether to offload the `DatasetFromList` storage to numpy. This diff generalizes "serialize" to "offload" and makes the offload function customizable, so that we can switch between implementations. The offload function is set through a context manager, to avoid passing this argument all the way down.

Reviewed By: sstsai-adl

Differential Revision: D40818736

fbshipit-source-id: ed1b47eea86546def6c06f78bc12d6edf267df28
1 parent: c54429b
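
As a usage sketch of the new API (mirroring the tests added in this commit; `dataset_dicts` is a hypothetical placeholder for real dataset records, and `torch.tensor` stands in for any offload callable):

import torch
from detectron2.data.common import (
    DatasetFromList,
    set_default_dataset_from_list_serialize_method,
)

dataset_dicts = [1, 2, 3]  # hypothetical placeholder for real dataset records

# Pass a custom offload/serialize callable directly ...
ds = DatasetFromList(dataset_dicts, serialize=torch.tensor)

# ... or swap the default implementation for every DatasetFromList created
# inside the `with` block, without threading the argument all the way down.
with set_default_dataset_from_list_serialize_method(torch.tensor):
    ds = DatasetFromList(dataset_dicts, serialize=True)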

File tree

2 files changed: +95 −30

detectron2/data/common.py

Lines changed: 78 additions & 29 deletions
@@ -1,17 +1,21 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
+import contextlib
 import copy
 import itertools
 import logging
 import numpy as np
 import pickle
 import random
+from typing import Callable, Union
 import torch.utils.data as data
 from torch.utils.data.sampler import Sampler

 from detectron2.utils.serialize import PicklableWrapper

 __all__ = ["MapDataset", "DatasetFromList", "AspectRatioGroupedDataset", "ToIterableDataset"]

+logger = logging.getLogger(__name__)
+

 def _shard_iterator_dataloader_worker(iterable):
     # Shard the iterable if we're currently inside pytorch dataloader worker.
@@ -106,56 +110,101 @@ def __getitem__(self, idx):
         )


+class NumpySerializedList(object):
+    """
+    A list-like object whose items are serialized and stored in a Numpy array. When
+    forking a process that has NumpySerializedList, subprocesses can read the same list
+    without triggering copy-on-access, therefore they will share RAM for the list. This
+    avoids the issue in https://github.com/pytorch/pytorch/issues/13246
+    """
+
+    def __init__(self, lst: list):
+        self._lst = lst
+
+        def _serialize(data):
+            buffer = pickle.dumps(data, protocol=-1)
+            return np.frombuffer(buffer, dtype=np.uint8)
+
+        logger.info(
+            "Serializing {} elements to byte tensors and concatenating them all ...".format(
+                len(self._lst)
+            )
+        )
+        self._lst = [_serialize(x) for x in self._lst]
+        self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
+        self._addr = np.cumsum(self._addr)
+        self._lst = np.concatenate(self._lst)
+        logger.info("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))
+
+    def __len__(self):
+        return len(self._addr)
+
+    def __getitem__(self, idx):
+        start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
+        end_addr = self._addr[idx].item()
+        bytes = memoryview(self._lst[start_addr:end_addr])
+
+        # @lint-ignore PYTHONPICKLEISBAD
+        return pickle.loads(bytes)
+
+
+_DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = NumpySerializedList
+
+
+@contextlib.contextmanager
+def set_default_dataset_from_list_serialize_method(new):
+    """
+    Context manager for using custom serialize function when creating DatasetFromList
+    """
+
+    global _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
+    orig = _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
+    _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = new
+    yield
+    _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = orig
+
+
 class DatasetFromList(data.Dataset):
     """
     Wrap a list to a torch Dataset. It produces elements of the list as data.
     """

-    def __init__(self, lst: list, copy: bool = True, serialize: bool = True):
+    def __init__(
+        self,
+        lst: list,
+        copy: bool = True,
+        serialize: Union[bool, Callable] = True,
+    ):
         """
         Args:
             lst (list): a list which contains elements to produce.
             copy (bool): whether to deepcopy the element when producing it,
                 so that the result can be modified in place without affecting the
                 source in the list.
-            serialize (bool): whether to hold memory using serialized objects, when
-                enabled, data loader workers can use shared RAM from master
-                process instead of making a copy.
+            serialize (bool or callable): whether to serialize the storage to another
+                backend. If `True`, the default serialize method will be used; if given
+                a callable, that callable will be used as the serialize method.
         """
         self._lst = lst
         self._copy = copy
-        self._serialize = serialize
-
-        def _serialize(data):
-            buffer = pickle.dumps(data, protocol=-1)
-            return np.frombuffer(buffer, dtype=np.uint8)
+        if not isinstance(serialize, (bool, Callable)):
+            raise TypeError(f"Unsupported type for argument `serialize`: {serialize}")
+        self._serialize = serialize is not False

         if self._serialize:
-            logger = logging.getLogger(__name__)
-            logger.info(
-                "Serializing {} elements to byte tensors and concatenating them all ...".format(
-                    len(self._lst)
-                )
+            serialize_method = (
+                serialize
+                if isinstance(serialize, Callable)
+                else _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
             )
-            self._lst = [_serialize(x) for x in self._lst]
-            self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
-            self._addr = np.cumsum(self._addr)
-            self._lst = np.concatenate(self._lst)
-            logger.info("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))
+            logger.info(f"Serializing the dataset using: {serialize_method}")
+            self._lst = serialize_method(self._lst)

     def __len__(self):
-        if self._serialize:
-            return len(self._addr)
-        else:
-            return len(self._lst)
+        return len(self._lst)

     def __getitem__(self, idx):
-        if self._serialize:
-            start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
-            end_addr = self._addr[idx].item()
-            bytes = memoryview(self._lst[start_addr:end_addr])
-            return pickle.loads(bytes)
-        elif self._copy:
+        if self._copy and not self._serialize:
             return copy.deepcopy(self._lst[idx])
         else:
             return self._lst[idx]
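
Any callable that consumes the list and returns an indexable, sized object can serve as the offload function. As an illustration of writing one (a hedged sketch only; `TorchSerializedList` is a hypothetical name, not part of this diff), the same pickle-and-concatenate layout used by `NumpySerializedList` can be backed by torch tensors:

import pickle

import numpy as np
import torch


class TorchSerializedList:
    """Hypothetical sketch: the NumpySerializedList layout, backed by torch tensors."""

    def __init__(self, lst: list):
        # Pickle each element into a flat uint8 buffer, as NumpySerializedList does.
        serialized = [
            np.frombuffer(pickle.dumps(x, protocol=-1), dtype=np.uint8) for x in lst
        ]
        # Cumulative end offsets: element i occupies bytes [addr[i-1], addr[i]).
        self._addr = torch.from_numpy(np.cumsum([len(x) for x in serialized]))
        self._lst = torch.from_numpy(np.concatenate(serialized))

    def __len__(self):
        return len(self._addr)

    def __getitem__(self, idx):
        start = 0 if idx == 0 else self._addr[idx - 1].item()
        end = self._addr[idx].item()
        # Safe here: we only unpickle bytes this object produced itself.
        return pickle.loads(memoryview(self._lst[start:end].numpy()))

Such a class could then be passed as `serialize=TorchSerializedList`, or installed as the default for a scope via `set_default_dataset_from_list_serialize_method`.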

tests/data/test_dataset.py

Lines changed: 17 additions & 1 deletion
@@ -19,7 +19,10 @@
     build_detection_test_loader,
     build_detection_train_loader,
 )
-from detectron2.data.common import AspectRatioGroupedDataset
+from detectron2.data.common import (
+    AspectRatioGroupedDataset,
+    set_default_dataset_from_list_serialize_method,
+)
 from detectron2.data.samplers import InferenceSampler, TrainingSampler


@@ -41,6 +44,19 @@ def test_using_lazy_path(self):
         self.assertTrue(isinstance(path, LazyPath))
         self.assertEqual(os.fspath(path), _a_slow_func(i))

+    def test_alternative_serialize_method(self):
+        dataset = [1, 2, 3]
+        dataset = DatasetFromList(dataset, serialize=torch.tensor)
+        self.assertEqual(dataset[2], torch.tensor(3))
+
+    def test_change_default_serialize_method(self):
+        dataset = [1, 2, 3]
+        with set_default_dataset_from_list_serialize_method(torch.tensor):
+            dataset_1 = DatasetFromList(dataset, serialize=True)
+            self.assertEqual(dataset_1[2], torch.tensor(3))
+        dataset_2 = DatasetFromList(dataset, serialize=True)
+        self.assertEqual(dataset_2[2], 3)
+

 class TestMapDataset(unittest.TestCase):
     @staticmethod
