Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 46 additions & 48 deletions patchify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
"""

from typing import Tuple, Union, cast

import numpy as np
from .view_as_windows import view_as_windows


Imsize = Union[Tuple[int, int], Tuple[int, int, int]]


Expand All @@ -17,14 +15,16 @@ def patchify(image: np.ndarray, patch_size: Imsize, step: int = 1) -> np.ndarray

Parameters
----------
image: the image to be split. It can be 2d (m, n) or 3d (k, m, n)
image: the image to be split. It can be 2D (m, n) or 3D (k, m, n)
patch_size: the size of a single patch
step: the step size between patches

Examples
--------
>>> image = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
>>> patches = patchify(image, (2, 2), step=1) # split image into 2*3 small 2*2 patches.
>>> image = np.array([[1, 2, 3, 4],
... [5, 6, 7, 8],
... [9, 10, 11, 12]])
>>> patches = patchify(image, (2, 2), step=1) # split into 2×3 small 2×2 patches
>>> assert patches.shape == (2, 3, 2, 2)
>>> reconstructed_image = unpatchify(patches, image.shape)
>>> assert (reconstructed_image == image).all()
Expand All @@ -34,42 +34,43 @@ def patchify(image: np.ndarray, patch_size: Imsize, step: int = 1) -> np.ndarray

def unpatchify(patches: np.ndarray, imsize: Imsize) -> np.ndarray:
"""
Merge patches into the orignal image
Merge patches into the original image.

Parameters
----------
patches: the patches to merge. It can be patches for a 2d image (k, l, m, n)
or 3d volume (i, j, k, l, m, n)
patches: patches to merge. Can be patches for a 2D image (k, l, m, n)
or 3D volume (i, j, k, l, m, n)
imsize: the size of the original image or volume

Examples
--------
>>> image = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
>>> patches = patchify(image, (2, 2), step=1) # split image into 2*3 small 2*2 patches.
>>> image = np.array([[1, 2, 3, 4],
... [5, 6, 7, 8],
... [9, 10, 11, 12]])
>>> patches = patchify(image, (2, 2), step=1) # split into 2×3 small 2×2 patches
>>> assert patches.shape == (2, 3, 2, 2)
>>> reconstructed_image = unpatchify(patches, image.shape)
>>> assert (reconstructed_image == image).all()
"""

assert len(patches.shape) / 2 == len(
assert len(patches.shape) // 2 == len(
imsize
), "The patches dimension is not equal to the original image size"
), "The patches dimension does not match the original image size"

if len(patches.shape) == 4:
return _unpatchify2d(patches, cast(Tuple[int, int], imsize))
elif len(patches.shape) == 6:
return _unpatchify3d(patches, cast(Tuple[int, int, int], imsize))
else:
raise NotImplementedError(
"Unpatchify only supports a matrix of 2D patches (k, l, m, n)"
"Unpatchify only supports matrices of 2D patches (k, l, m, n) "
f"or 3D volumes (i, j, k, l, m, n), but got: {patches.shape}"
)


def _unpatchify2d( # pylint: disable=too-many-locals
patches: np.ndarray, imsize: Tuple[int, int]
) -> np.ndarray:

def _unpatchify2d(patches: np.ndarray, imsize: Tuple[int, int]) -> np.ndarray:
"""
Reconstruct a 2D image from its patches.
"""
assert len(patches.shape) == 4

i_h, i_w = imsize
Expand All @@ -80,8 +81,8 @@ def _unpatchify2d( # pylint: disable=too-many-locals
s_w = 0 if n_w <= 1 else (i_w - p_w) / (n_w - 1)
s_h = 0 if n_h <= 1 else (i_h - p_h) / (n_h - 1)

# The step size should be same for all patches, otherwise the patches are unable
# to reconstruct into a image
# The step size should be the same for all patches; otherwise,
# the patches cannot reconstruct into a valid image
if int(s_w) != s_w:
raise NonUniformStepSizeError(i_w, n_w, p_w, s_w)
if int(s_h) != s_h:
Expand All @@ -90,31 +91,27 @@ def _unpatchify2d( # pylint: disable=too-many-locals
s_h = int(s_h)

i, j = 0, 0

while True:
i_o, j_o = i * s_h, j * s_w

image[i_o : i_o + p_h, j_o : j_o + p_w] = patches[i, j]
i_o, j_o = int(i * s_h), int(j * s_w)
image[int(i_o): int(i_o + p_h), int(j_o): int(j_o + p_w)] = patches[i, j]

if j < n_w - 1:
j = min((j_o + p_w) // s_w, n_w - 1)
j = min(int((j_o + p_w) // s_w), n_w - 1)
elif i < n_h - 1 and j >= n_w - 1:
# Go to next row
i = min((i_o + p_h) // s_h, n_h - 1)
i = min(int((i_o + p_h) // s_h), n_h - 1)
j = 0
elif i >= n_h - 1 and j >= n_w - 1:
# Finished
break
else:
raise RuntimeError("Unreachable")

return image


def _unpatchify3d( # pylint: disable=too-many-locals
patches: np.ndarray, imsize: Tuple[int, int, int]
) -> np.ndarray:

def _unpatchify3d(patches: np.ndarray, imsize: Tuple[int, int, int]) -> np.ndarray:
"""
Reconstruct a 3D volume from its patches.
"""
assert len(patches.shape) == 6

i_h, i_w, i_c = imsize
Expand All @@ -126,8 +123,8 @@ def _unpatchify3d( # pylint: disable=too-many-locals
s_h = 0 if n_h <= 1 else (i_h - p_h) / (n_h - 1)
s_c = 0 if n_c <= 1 else (i_c - p_c) / (n_c - 1)

# The step size should be same for all patches, otherwise the patches are unable
# to reconstruct into a image
# The step size should be the same for all patches; otherwise,
# the patches cannot reconstruct into a valid volume
if int(s_w) != s_w:
raise NonUniformStepSizeError(i_w, n_w, p_w, s_w)
if int(s_h) != s_h:
Expand All @@ -140,24 +137,20 @@ def _unpatchify3d( # pylint: disable=too-many-locals
s_c = int(s_c)

i, j, k = 0, 0, 0

while True:

i_o, j_o, k_o = i * s_h, j * s_w, k * s_c

image[i_o : i_o + p_h, j_o : j_o + p_w, k_o : k_o + p_c] = patches[i, j, k]
i_o, j_o, k_o = int(i * s_h), int(j * s_w), int(k * s_c)
image[int(i_o): int(i_o + p_h), int(j_o): int(j_o + p_w), int(k_o): int(k_o + p_c)] = patches[i, j, k]

if k < n_c - 1:
k = min((k_o + p_c) // s_c, n_c - 1)
k = min(int((k_o + p_c) // s_c), n_c - 1)
elif j < n_w - 1 and k >= n_c - 1:
j = min((j_o + p_w) // s_w, n_w - 1)
j = min(int((j_o + p_w) // s_w), n_w - 1)
k = 0
elif i < n_h - 1 and j >= n_w - 1 and k >= n_c - 1:
i = min((i_o + p_h) // s_h, n_h - 1)
i = min(int((i_o + p_h) // s_h), n_h - 1)
j = 0
k = 0
elif i >= n_h - 1 and j >= n_w - 1 and k >= n_c - 1:
# Finished
break
else:
raise RuntimeError("Unreachable")
Expand All @@ -166,18 +159,23 @@ def _unpatchify3d( # pylint: disable=too-many-locals


class NonUniformStepSizeError(RuntimeError):
def __init__(
self, imsize: int, n_patches: int, patch_size: int, step_size: float
) -> None:
"""
Raised when patch reconstruction requires a non-integer step size.
"""
def __init__(self, imsize: int, n_patches: int, patch_size: int, step_size: float) -> None:
super().__init__(imsize, n_patches, patch_size, step_size)
self.n_patches = n_patches
self.patch_size = patch_size
self.imsize = imsize
self.step_size = step_size

def __repr__(self) -> str:
return f"Unpatchify only supports reconstructing image with a uniform step size for all patches. \
However, reconstructing {self.n_patches} x {self.patch_size}px patches to an {self.imsize} image requires {self.step_size} as step size, which is not an integer."
return (
f"Unpatchify only supports reconstructing with a uniform step size for all patches. "
f"However, reconstructing {self.n_patches} × {self.patch_size} patches "
f"to an image/volume of size {self.imsize} requires a step size of {self.step_size}, "
"which is not an integer."
)

def __str__(self) -> str:
return self.__repr__()