Hello!
When all patches in a batch are filtered out, `LivePatchWSIDataloader.__next__` in `pathopatch/patch_extraction/dataset.py` returns `patches` as an empty list because the `len(patches) == 0` case is not handled, even though the method is annotated to return a `torch.Tensor`. This leads to type errors downstream, and sometimes to an `IndexError` when `patches[0]` is accessed.
```python
def __next__(self) -> Tuple[torch.Tensor, List[dict], List[np.ndarray]]:
    """Create one batch of patches

    Raises:
        StopIteration: If the end of the dataset is reached.

    Returns:
        Tuple[torch.Tensor, List[dict], List[np.ndarray]]:
            * torch.Tensor: Batch of patches, shape (batch_size, 3, patch_size, patch_size)
            * List[dict]: List of metadata for each patch
            * List[np.ndarray]: List of masks for each patch
    """
    patches = []
    metadata = []
    masks = []
    if self.i < len(self.element_list):
        batch_item_count = 0
        while batch_item_count < self.batch_size and self.i < len(
            self.element_list
        ):
            patch, meta, mask = self.dataset[self.element_list[self.i]]
            self.i += 1
            if patch is None and meta["discard_patch"]:
                self.discard_count += 1
                continue
            elif self.dataset.config.filter_patches:
                output = self.dataset.detector_model(
                    self.dataset.detector_transforms(patch)[None, ...]
                )
                output_prob = torch.softmax(output, dim=-1)
                prediction = torch.argmax(output_prob, dim=-1)
                if int(prediction) != 0:
                    self.discard_count += 1
                    continue
            patches.append(patch)
            metadata.append(meta)
            masks.append(mask)
            batch_item_count += 1
        if len(patches) > 1:
            patches = [torch.tensor(f) for f in patches]
            patches = torch.stack(patches)
        elif len(patches) == 1:
            patches = torch.tensor(patches[0][None, ...])
        return patches, metadata, masks
    else:
        raise StopIteration
```
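To make the failure mode concrete, here is a minimal sketch that copies only the stacking branch from the end of `__next__` into a standalone helper. `stack_batch` is a hypothetical name, and the `(3, 4, 4)` NumPy arrays are stand-ins for real patches. With two patches or one patch the helper returns a tensor; with zero patches it hands back the untouched empty list:

```python
import numpy as np
import torch


def stack_batch(patches):
    """Hypothetical copy of the stacking branch at the end of __next__."""
    if len(patches) > 1:
        patches = [torch.tensor(f) for f in patches]
        patches = torch.stack(patches)
    elif len(patches) == 1:
        patches = torch.tensor(patches[0][None, ...])
    # No branch for len(patches) == 0: the empty list falls through unchanged.
    return patches


full = stack_batch([np.zeros((3, 4, 4), dtype=np.float32)] * 2)
single = stack_batch([np.zeros((3, 4, 4), dtype=np.float32)])
empty = stack_batch([])

print(type(full).__name__, tuple(full.shape))      # Tensor (2, 3, 4, 4)
print(type(single).__name__, tuple(single.shape))  # Tensor (1, 3, 4, 4)
print(type(empty).__name__)                        # list
```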
Suggested fix:

```diff
 def __next__(self):
     ...
     if len(patches) > 1:
         patches = [torch.tensor(f) for f in patches]
         patches = torch.stack(patches)
     elif len(patches) == 1:
         patches = torch.tensor(patches[0][None, ...])
+    elif len(patches) == 0:
+        raise StopIteration
     return patches, metadata, masks
```
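Raising `StopIteration` here should be safe: discarded patches never increment `batch_item_count`, so the `while` loop can only finish with an empty `patches` list after `self.element_list` is exhausted, and no pending patches are skipped. I would be happy to open a PR for this; let me know if this is your preferred fix.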
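A minimal sketch of the patched control flow, assuming a simplified, hypothetical `MiniLoader` class that keeps only the batching loop of `LivePatchWSIDataloader.__next__` (no WSI reading, metadata, or masks; the `keep` callable is a stand-in for both the `discard_patch` check and the detector filter). It checks that a tail of discarded patches ends iteration cleanly instead of yielding an empty list:

```python
from typing import Callable, List, Sequence, Tuple

import torch


class MiniLoader:
    """Hypothetical, simplified stand-in for LivePatchWSIDataloader.

    Only the batching control flow of __next__ is reproduced; `keep`
    plays the role of the discard / detector filter.
    """

    def __init__(self, items: Sequence[torch.Tensor], batch_size: int,
                 keep: Callable[[torch.Tensor], bool]):
        self.items = list(items)
        self.batch_size = batch_size
        self.keep = keep
        self.i = 0
        self.discard_count = 0

    def __iter__(self):
        return self

    def __next__(self) -> Tuple[torch.Tensor, List[int]]:
        if self.i >= len(self.items):
            raise StopIteration
        patches, indices = [], []
        while len(patches) < self.batch_size and self.i < len(self.items):
            patch = self.items[self.i]
            self.i += 1
            if not self.keep(patch):
                # Discarded patches do not count toward the batch size.
                self.discard_count += 1
                continue
            patches.append(patch)
            indices.append(self.i - 1)
        if len(patches) == 0:
            # The proposed fix: only reachable once self.items is exhausted.
            raise StopIteration
        return torch.stack(patches), indices


items = [torch.full((3, 2, 2), float(v)) for v in [1, 2, 0, 0]]
loader = MiniLoader(items, batch_size=2, keep=lambda p: bool(p.sum() > 0))
batches = list(loader)
print(len(batches), tuple(batches[0][0].shape), loader.discard_count)  # 1 (2, 3, 2, 2) 2
```

Here the first call returns a full `(2, 3, 2, 2)` batch; the second call discards both zero patches, reaches the end of the list with nothing collected, and raises `StopIteration`, so iterating the loader yields only tensors.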
The quoted `__next__` excerpt is from `PathoPatcher/pathopatch/patch_extraction/dataset.py`, lines 782 to 827 at commit `9ef1f12`.