diff --git a/streamlit_image_comparison/__init__.py b/streamlit_image_comparison/__init__.py
index ee6b9ee..dbcbe09 100644
--- a/streamlit_image_comparison/__init__.py
+++ b/streamlit_image_comparison/__init__.py
@@ -12,210 +12,218 @@
__version__ = "0.0.5"
+
def exif_transpose(image: Image.Image):
- """
- Transpose a PIL image accordingly if it has an EXIF Orientation tag.
- Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()
- :param image: The image to transpose.
- :return: An image.
- """
- exif = image.getexif()
- orientation = exif.get(0x0112, 1) # default 1
- if orientation > 1:
- method = {
- 2: Image.FLIP_LEFT_RIGHT,
- 3: Image.ROTATE_180,
- 4: Image.FLIP_TOP_BOTTOM,
- 5: Image.TRANSPOSE,
- 6: Image.ROTATE_270,
- 7: Image.TRANSVERSE,
- 8: Image.ROTATE_90,
- }.get(orientation)
- if method is not None:
- image = image.transpose(method)
- del exif[0x0112]
- image.info["exif"] = exif.tobytes()
- return image
-
-def read_image_as_pil(image: Union[Image.Image, str, np.ndarray], exif_fix: bool = False):
- """
- Loads an image as PIL.Image.Image.
- Args:
- image : Can be image path or url (str), numpy image (np.ndarray) or PIL.Image
- """
- # https://stackoverflow.com/questions/56174099/how-to-load-images-larger-than-max-image-pixels-with-pil
- Image.MAX_IMAGE_PIXELS = None
-
- if isinstance(image, Image.Image):
- image_pil = image.convert('RGB')
- elif isinstance(image, str):
- # read image if str image path is provided
- try:
- image_pil = Image.open(
- requests.get(image, stream=True).raw if str(image).startswith("http") else image
- ).convert("RGB")
- if exif_fix:
- image_pil = exif_transpose(image_pil)
- except: # handle large/tiff image reading
- try:
- import skimage.io
- except ImportError:
- raise ImportError("Please run 'pip install -U scikit-image imagecodecs' for large image handling.")
- image_sk = skimage.io.imread(image).astype(np.uint8)
- if len(image_sk.shape) == 2: # b&w
- image_pil = Image.fromarray(image_sk, mode="1").convert("RGB")
- elif image_sk.shape[2] == 4: # rgba
- image_pil = Image.fromarray(image_sk, mode="RGBA").convert("RGB")
- elif image_sk.shape[2] == 3: # rgb
- image_pil = Image.fromarray(image_sk, mode="RGB")
- else:
- raise TypeError(f"image with shape: {image_sk.shape[3]} is not supported.")
- elif isinstance(image, np.ndarray):
- if image.shape[0] < 5: # image in CHW
- image = image[:, :, ::-1]
- image_pil = Image.fromarray(image).convert("RGB")
- else:
- raise TypeError("read image with 'pillow' using 'Image.open()'")
-
- return image_pil
+ """
+ Transpose a PIL image accordingly if it has an EXIF Orientation tag.
+ Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()
+ :param image: The image to transpose.
+ :return: An image.
+ """
+ exif = image.getexif()
+ orientation = exif.get(0x0112, 1) # default 1
+ if orientation > 1:
+ method = {
+ 2: Image.FLIP_LEFT_RIGHT,
+ 3: Image.ROTATE_180,
+ 4: Image.FLIP_TOP_BOTTOM,
+ 5: Image.TRANSPOSE,
+ 6: Image.ROTATE_270,
+ 7: Image.TRANSVERSE,
+ 8: Image.ROTATE_90,
+ }.get(orientation)
+ if method is not None:
+ image = image.transpose(method)
+ del exif[0x0112]
+ image.info["exif"] = exif.tobytes()
+ return image
+
+
+def read_image_as_pil(image: Union[Image.Image, str, np.ndarray], exif_fix: bool = False, img_type: str = "RGBA"):
+ """
+ Loads an image as PIL.Image.Image.
+ Args:
+ image : Can be image path or url (str), numpy image (np.ndarray) or PIL.Image
+ """
+ # https://stackoverflow.com/questions/56174099/how-to-load-images-larger-than-max-image-pixels-with-pil
+ Image.MAX_IMAGE_PIXELS = None
+
+ if isinstance(image, Image.Image):
+ image_pil = image.convert(img_type)
+ elif isinstance(image, str):
+ # read image if str image path is provided
+ try:
+ image_pil = Image.open(
+ requests.get(image, stream=True).raw if str(image).startswith("http") else image
+ ).convert(img_type)
+ if exif_fix:
+ image_pil = exif_transpose(image_pil)
+ except: # handle large/tiff image reading
+ try:
+ import skimage.io
+ except ImportError:
+ raise ImportError("Please run 'pip install -U scikit-image imagecodecs' for large image handling.")
+ image_sk = skimage.io.imread(image).astype(np.uint8)
+ if len(image_sk.shape) == 2: # b&w
+ image_pil = Image.fromarray(image_sk, mode="1").convert(img_type)
+ elif image_sk.shape[2] == 4: # rgba
+ image_pil = Image.fromarray(image_sk, mode=img_type)
+ elif image_sk.shape[2] == 3: # rgb
+ image_pil = Image.fromarray(image_sk, mode=img_type)
+ else:
+ raise TypeError(f"image with shape: {image_sk.shape[3]} is not supported.")
+ elif isinstance(image, np.ndarray):
+ if image.shape[0] < 5: # image in CHW
+ image = image[:, :, ::-1]
+ image_pil = Image.fromarray(image).convert(img_type)
+ else:
+ raise TypeError("read image with 'pillow' using 'Image.open()'")
+
+ return image_pil
+
def pillow_to_base64(image: Image.Image) -> str:
- """
- Convert a PIL image to a base64-encoded string.
-
- Parameters
- ----------
- image: PIL.Image.Image
- The image to be converted.
-
- Returns
- -------
- str
- The base64-encoded string.
- """
- in_mem_file = io.BytesIO()
- image.save(in_mem_file, format="JPEG", subsampling=0, quality=100)
- img_bytes = in_mem_file.getvalue() # bytes
- image_str = base64.b64encode(img_bytes).decode("utf-8")
- base64_src = f"data:image/jpg;base64,{image_str}"
- return base64_src
+ """
+ Convert a PIL image to a base64-encoded string.
+
+ Parameters
+ ----------
+ image: PIL.Image.Image
+ The image to be converted.
+
+ Returns
+ -------
+ str
+ The base64-encoded string.
+ """
+ in_mem_file = io.BytesIO()
+ image.save(in_mem_file, format="JPEG", subsampling=0, quality=100)
+ img_bytes = in_mem_file.getvalue() # bytes
+ image_str = base64.b64encode(img_bytes).decode("utf-8")
+ base64_src = f"data:image/jpg;base64,{image_str}"
+ return base64_src
+
def local_file_to_base64(image_path: str) -> str:
- """
- Convert a local image file to a base64-encoded string.
-
- Parameters
- ----------
- image_path: str
- The path to the image file.
-
- Returns
- -------
- str
- The base64-encoded string.
- """
- file_ = open(image_path, "rb")
- img_bytes = file_.read()
- image_str = base64.b64encode(img_bytes).decode("utf-8")
- file_.close()
- base64_src = f"data:image/jpg;base64,{image_str}"
- return base64_src
+ """
+ Convert a local image file to a base64-encoded string.
+
+ Parameters
+ ----------
+ image_path: str
+ The path to the image file.
+
+ Returns
+ -------
+ str
+ The base64-encoded string.
+ """
+ file_ = open(image_path, "rb")
+ img_bytes = file_.read()
+ image_str = base64.b64encode(img_bytes).decode("utf-8")
+ file_.close()
+ base64_src = f"data:image/jpg;base64,{image_str}"
+ return base64_src
+
def pillow_local_file_to_base64(image: Image.Image, temp_dir: str):
- """
- Convert a Pillow image to a base64 string, using a temporary file on disk.
+ """
+ Convert a Pillow image to a base64 string, using a temporary file on disk.
+
+ Parameters
+ ----------
+ image : PIL.Image.Image
+ The Pillow image to convert.
+ temp_dir : str
+ The directory to use for the temporary file.
- Parameters
- ----------
- image : PIL.Image.Image
- The Pillow image to convert.
- temp_dir : str
- The directory to use for the temporary file.
+ Returns
+ -------
+ str
+ A base64-encoded string representing the image.
+ """
+ # Create temporary file path using os.path.join()
+ img_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".png")
- Returns
- -------
- str
- A base64-encoded string representing the image.
- """
- # Create temporary file path using os.path.join()
- img_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".jpg")
+ # Save image to temporary file
+ image.save(img_path, subsampling=0, quality=100)
- # Save image to temporary file
- image.save(img_path, subsampling=0, quality=100)
+ # Convert temporary file to base64 string
+ base64_src = local_file_to_base64(img_path)
- # Convert temporary file to base64 string
- base64_src = local_file_to_base64(img_path)
+ return base64_src
- return base64_src
def image_comparison(
- img1: Union[Image.Image, str, np.ndarray],
- img2: Union[Image.Image, str, np.ndarray],
- label1: str = "1",
- label2: str = "2",
- width: int = 704,
- show_labels: bool = True,
- starting_position: int = 50,
- make_responsive: bool = True,
- in_memory: bool = False,
+ img1: Union[Image.Image, str, np.ndarray],
+ img2: Union[Image.Image, str, np.ndarray],
+ label1: str = "1",
+ label2: str = "2",
+ width: int = 704,
+ show_labels: bool = True,
+ img_type="RGBA",
+ starting_position: int = 50,
+ make_responsive: bool = True,
+ in_memory: bool = False,
) -> components.html:
- """
- Create a comparison slider for two images.
-
- Parameters
- ----------
- img1: str, PIL Image, or numpy array
- Data for the first image.
- img2: str, PIL Image, or numpy array
- Data for the second image.
- label1: str, optional
- Label for the first image. Default is "1".
- label2: str, optional
- Label for the second image. Default is "2".
- width: int, optional
- Width of the component in pixels. Default is 704.
- show_labels: bool, optional
- Whether to show labels on the images. Default is True.
- starting_position: int, optional
- Starting position of the slider as a percentage (0-100). Default is 50.
- make_responsive: bool, optional
- Whether to enable responsive mode. Default is True.
- in_memory: bool, optional
- Whether to handle pillow to base64 conversion in memory without saving to local. Default is False.
-
- Returns
- -------
- components.html
- Returns a static component with a timeline
- """
- # Prepare images
- img1_pillow = read_image_as_pil(img1)
- img2_pillow = read_image_as_pil(img2)
-
- img_width, img_height = img1_pillow.size
- h_to_w = img_height / img_width
- height = int((width * h_to_w) * 0.95)
-
- if in_memory:
- # Convert images to base64 strings
- img1 = pillow_to_base64(img1_pillow)
- img2 = pillow_to_base64(img2_pillow)
- else:
- # Create base64 strings from temporary files
- os.makedirs(TEMP_DIR, exist_ok=True)
- for file_ in os.listdir(TEMP_DIR):
- if file_.endswith(".jpg"):
- os.remove(os.path.join(TEMP_DIR, file_))
- img1 = pillow_local_file_to_base64(img1_pillow, TEMP_DIR)
- img2 = pillow_local_file_to_base64(img2_pillow, TEMP_DIR)
-
- # Load CSS and JS
- cdn_path = "https://cdn.knightlab.com/libs/juxtapose/latest"
- css_block = f''
- js_block = f''
-
- # write html block
- htmlcode = f"""
+ """
+ Create a comparison slider for two images.
+
+ Parameters
+ ----------
+ img1: str, PIL Image, or numpy array
+ Data for the first image.
+ img2: str, PIL Image, or numpy array
+ Data for the second image.
+ label1: str, optional
+ Label for the first image. Default is "1".
+ label2: str, optional
+ Label for the second image. Default is "2".
+ width: int, optional
+ Width of the component in pixels. Default is 704.
+ show_labels: bool, optional
+ Whether to show labels on the images. Default is True.
+ img_type: color type
+ starting_position: int, optional
+ Starting position of the slider as a percentage (0-100). Default is 50.
+ make_responsive: bool, optional
+ Whether to enable responsive mode. Default is True.
+ in_memory: bool, optional
+ Whether to handle pillow to base64 conversion in memory without saving to local. Default is False.
+
+ Returns
+ -------
+ components.html
+ Returns a static component with a timeline
+ """
+ # Prepare images
+ img1_pillow = read_image_as_pil(img1, img_type=img_type)
+ img2_pillow = read_image_as_pil(img2, img_type=img_type)
+
+ img_width, img_height = img1_pillow.size
+ h_to_w = img_height / img_width
+ height = int((width * h_to_w) * 0.95)
+
+ if in_memory:
+ # Convert images to base64 strings
+ img1 = pillow_to_base64(img1_pillow)
+ img2 = pillow_to_base64(img2_pillow)
+ else:
+ # Create base64 strings from temporary files
+ os.makedirs(TEMP_DIR, exist_ok=True)
+ for file_ in os.listdir(TEMP_DIR):
+ if file_.endswith(".jpg"):
+ os.remove(os.path.join(TEMP_DIR, file_))
+ img1 = pillow_local_file_to_base64(img1_pillow, TEMP_DIR)
+ img2 = pillow_local_file_to_base64(img2_pillow, TEMP_DIR)
+
+ # Load CSS and JS
+ cdn_path = "https://cdn.knightlab.com/libs/juxtapose/latest"
+ css_block = f''
+ js_block = f''
+
+ # write html block
+ htmlcode = f"""
{css_block}
{js_block}
@@ -241,6 +249,6 @@ def image_comparison(
}});
"""
- static_component = components.html(htmlcode, height=height, width=width)
+ static_component = components.html(htmlcode, height=height, width=width)
- return static_component
+ return static_component