@@ -1,17 +1,21 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
+import contextlib
 import copy
 import itertools
 import logging
 import numpy as np
 import pickle
 import random
+from typing import Callable, Union
 import torch.utils.data as data
 from torch.utils.data.sampler import Sampler
 
 from detectron2.utils.serialize import PicklableWrapper
 
 __all__ = ["MapDataset", "DatasetFromList", "AspectRatioGroupedDataset", "ToIterableDataset"]
 
+logger = logging.getLogger(__name__)
+
 
 def _shard_iterator_dataloader_worker(iterable):
     # Shard the iterable if we're currently inside pytorch dataloader worker.
@@ -106,56 +110,101 @@ def __getitem__(self, idx):
                 )
 
 
+class NumpySerializedList(object):
+    """
+    A list-like object whose items are serialized and stored in a Numpy Array. When
+    forking a process that has NumpySerializedList, subprocesses can read the same list
+    without triggering copy-on-access, therefore they will share RAM for the list. This
+    avoids the issue in https://github.com/pytorch/pytorch/issues/13246
+    """
+
+    def __init__(self, lst: list):
+        self._lst = lst
+
+        def _serialize(data):
+            buffer = pickle.dumps(data, protocol=-1)
+            return np.frombuffer(buffer, dtype=np.uint8)
+
+        logger.info(
+            "Serializing {} elements to byte tensors and concatenating them all ...".format(
+                len(self._lst)
+            )
+        )
+        self._lst = [_serialize(x) for x in self._lst]
+        self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
+        self._addr = np.cumsum(self._addr)
+        self._lst = np.concatenate(self._lst)
+        logger.info("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))
+
+    def __len__(self):
+        return len(self._addr)
+
+    def __getitem__(self, idx):
+        start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
+        end_addr = self._addr[idx].item()
+        bytes = memoryview(self._lst[start_addr:end_addr])
+
+        # @lint-ignore PYTHONPICKLEISBAD
+        return pickle.loads(bytes)
+
+
+_DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = NumpySerializedList
+
+
+@contextlib.contextmanager
+def set_default_dataset_from_list_serialize_method(new):
+    """
+    Context manager for using a custom serialize function when creating DatasetFromList
+    """
+
+    global _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
+    orig = _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
+    _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = new
+    yield
+    _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = orig
+
+
 class DatasetFromList(data.Dataset):
     """
     Wrap a list to a torch Dataset. It produces elements of the list as data.
     """
 
-    def __init__(self, lst: list, copy: bool = True, serialize: bool = True):
+    def __init__(
+        self,
+        lst: list,
+        copy: bool = True,
+        serialize: Union[bool, Callable] = True,
+    ):
115 | 178 | """
|
116 | 179 | Args:
|
117 | 180 | lst (list): a list which contains elements to produce.
|
118 | 181 | copy (bool): whether to deepcopy the element when producing it,
|
119 | 182 | so that the result can be modified in place without affecting the
|
120 | 183 | source in the list.
|
-            serialize (bool): whether to hold memory using serialized objects, when
-                enabled, data loader workers can use shared RAM from master
-                process instead of making a copy.
+            serialize (bool or callable): whether to serialize the storage to another
+                backend. If `True`, the default serialize method will be used; if given
+                a callable, that callable will be used as the serialize method.
124 | 187 | """
|
125 | 188 | self._lst = lst
|
126 | 189 | self._copy = copy
|
-        self._serialize = serialize
-
-        def _serialize(data):
-            buffer = pickle.dumps(data, protocol=-1)
-            return np.frombuffer(buffer, dtype=np.uint8)
+        if not isinstance(serialize, (bool, Callable)):
+            raise TypeError(f"Unsupported type for argument `serialize`: {serialize}")
+        self._serialize = serialize is not False
 
         if self._serialize:
-            logger = logging.getLogger(__name__)
-            logger.info(
-                "Serializing {} elements to byte tensors and concatenating them all ...".format(
-                    len(self._lst)
-                )
+            serialize_method = (
+                serialize
+                if isinstance(serialize, Callable)
+                else _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
             )
-            self._lst = [_serialize(x) for x in self._lst]
-            self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
-            self._addr = np.cumsum(self._addr)
-            self._lst = np.concatenate(self._lst)
-            logger.info("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))
+            logger.info(f"Serializing the dataset using: {serialize_method}")
+            self._lst = serialize_method(self._lst)
 
     def __len__(self):
-        if self._serialize:
-            return len(self._addr)
-        else:
-            return len(self._lst)
+        return len(self._lst)
 
     def __getitem__(self, idx):
-        if self._serialize:
-            start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
-            end_addr = self._addr[idx].item()
-            bytes = memoryview(self._lst[start_addr:end_addr])
-            return pickle.loads(bytes)
-        elif self._copy:
+        if self._copy and not self._serialize:
             return copy.deepcopy(self._lst[idx])
         else:
             return self._lst[idx]
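
For context, a minimal sketch of how the new serialized list behaves, assuming this patch lands in detectron2/data/common.py (the module that defines `MapDataset` and `DatasetFromList`):

# Minimal usage sketch (assumes the patched detectron2.data.common module).
from detectron2.data.common import NumpySerializedList

# Each dict is pickled into a uint8 buffer; all buffers are concatenated into one
# numpy array, so forked dataloader workers share that memory instead of copying it.
records = [{"file_name": f"img_{i}.jpg", "image_id": i} for i in range(1000)]
slist = NumpySerializedList(records)

assert len(slist) == 1000
assert slist[3] == records[3]  # indexing unpickles the requested item on demand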
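
The context manager swaps the module-level default used by every `DatasetFromList` constructed inside the `with` block. A sketch with a hypothetical `IdentitySerializedList` backend; any class that takes the list in its constructor and supports `__len__`/`__getitem__` should fit the same slot:

from detectron2.data.common import (
    DatasetFromList,
    set_default_dataset_from_list_serialize_method,
)

class IdentitySerializedList:
    """Hypothetical backend that keeps the list as-is (illustration only)."""

    def __init__(self, lst: list):
        self._lst = lst

    def __len__(self):
        return len(self._lst)

    def __getitem__(self, idx):
        return self._lst[idx]

# Datasets built inside the block use the custom backend; the previous default
# (NumpySerializedList) is restored when the block exits.
with set_default_dataset_from_list_serialize_method(IdentitySerializedList):
    ds = DatasetFromList(list(range(10)))
assert len(ds) == 10 and ds[7] == 7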
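
The `serialize` argument itself also accepts a callable directly, bypassing the module default. A sketch of the callable form and of `serialize=False`, again under the same module-path assumption:

from detectron2.data.common import DatasetFromList, NumpySerializedList

# Passing a callable uses it as the serialize method (equivalent here to serialize=True,
# since NumpySerializedList is also the default).
ds = DatasetFromList([{"x": i} for i in range(5)], serialize=NumpySerializedList)
assert ds[2] == {"x": 2}

# serialize=False keeps the raw list; with copy=True each access returns a deepcopy,
# so the result can be mutated without touching the source list.
ds_raw = DatasetFromList([{"x": i} for i in range(5)], copy=True, serialize=False)
item = ds_raw[0]
item["x"] = 100
assert ds_raw[0] == {"x": 0}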