[compatibility] Fix torchvision transforms NotImplementedError #104
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
The following error happens when using DEIM on torchvision > 0.20.0. The root cause is version incompatibility. In pytorch/vision#8787 (included in v0.21.0), torchvision made the overridable v2 transform function public, changing the name from _transform() to transform(). There is no impact on functionality otherwise.
This issue is also described in #33 , posting here and opening this PR to fix the updated version in the same spirit as #47 . I had no problems after this fix. Alternatively, @ShihuaHuang95 if this fix causes problems, maybe just pin torchvision on 0.20.0? Looking forward to DEIMv2!
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/ec2-user/DEIM/train.py", line 84, in
[rank0]: main(args)
[rank0]: File "/home/ec2-user/DEIM/train.py", line 54, in main
[rank0]: solver.fit()
[rank0]: File "/home/ec2-user/DEIM/engine/solver/det_solver.py", line 76, in fit
[rank0]: train_stats = train_one_epoch(
[rank0]: ^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/DEIM/engine/solver/det_engine.py", line 42, in train_one_epoch
[rank0]: for i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
[rank0]: File "/home/ec2-user/DEIM/engine/misc/logger.py", line 215, in log_every
[rank0]: for obj in iterable:
[rank0]: File "/home/ec2-user/anaconda3/envs/deim/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 733, in next
[rank0]: data = self._next_data()
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/anaconda3/envs/deim/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1515, in _next_data
[rank0]: return self._process_data(data, worker_id)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/anaconda3/envs/deim/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1550, in _process_data
[rank0]: data.reraise()
[rank0]: File "/home/ec2-user/anaconda3/envs/deim/lib/python3.11/site-packages/torch/_utils.py", line 750, in reraise
[rank0]: raise exception
[rank0]: NotImplementedError: Caught NotImplementedError in DataLoader worker process 0.
[rank0]: Original Traceback (most recent call last):
[rank0]: File "/home/ec2-user/anaconda3/envs/deim/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
[rank0]: data = fetcher.fetch(index) # type: ignore[possibly-undefined]
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/anaconda3/envs/deim/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
[rank0]: data = [self.dataset[idx] for idx in possibly_batched_index]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/anaconda3/envs/deim/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 52, in
[rank0]: data = [self.dataset[idx] for idx in possibly_batched_index]
[rank0]: ~~~~~~~~~~~~^^^^^
[rank0]: File "/home/ec2-user/DEIM/engine/data/dataset/coco_dataset.py", line 44, in getitem
[rank0]: img, target, _ = self._transforms(img, target, self)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/anaconda3/envs/deim/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/anaconda3/envs/deim/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/DEIM/engine/data/transforms/container.py", line 58, in forward
[rank0]: return self.get_forward(self.policy['name'])(*inputs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/DEIM/engine/data/transforms/container.py", line 100, in stop_epoch_forward
[rank0]: sample = transform(sample)
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/anaconda3/envs/deim/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/anaconda3/envs/deim/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/anaconda3/envs/deim/lib/python3.11/site-packages/torchvision/transforms/v2/_transform.py", line 68, in forward
[rank0]: flat_outputs = [
[rank0]: ^
[rank0]: File "/home/ec2-user/anaconda3/envs/deim/lib/python3.11/site-packages/torchvision/transforms/v2/_transform.py", line 69, in
[rank0]: self.transform(inpt, params) if needs_transform else inpt
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/anaconda3/envs/deim/lib/python3.11/site-packages/torchvision/transforms/v2/_transform.py", line 55, in transform
[rank0]: raise NotImplementedError
[rank0]: NotImplementedError