diff --git a/00_index.ipynb b/00_index.ipynb index 13143226..1705e4ee 100644 --- a/00_index.ipynb +++ b/00_index.ipynb @@ -32,6 +32,10 @@ "- [SciPy](./23_library_scipy.ipynb)\n", "- [Pandas](./24_library_pandas.ipynb)\n", "\n", + "# Hands-On Projects\n", + "\n", + "- [Image Classification](./31_image_classification.ipynb)\n", + "\n", "# Additional Topics\n", "\n", "- [Parallelism and concurrency in Python](./14_threads.ipynb)\n" diff --git a/31_image_classification.ipynb b/31_image_classification.ipynb new file mode 100644 index 00000000..bcddacec --- /dev/null +++ b/31_image_classification.ipynb @@ -0,0 +1,2021 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Image Classification Notebook" + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "# Table of Contents\n", + " - [Image Classification Notebook](#Image-Classification-Notebook)\n", + " - [References](#References)\n", + " - [Libraries](#Libraries)\n", + " - [Introduction](#Introduction)\n", + " - [Classes](#Classes)\n", + " - [Functions](#Functions)\n", + " - [Dataset](#Dataset)\n", + " - [Load data](#Load-data)\n", + " - [Explore image processing](#Explore-image-processing)\n", + " - [Example image](#Example-image)\n", + " - [Geometric transformation](#Geometric-transformation)\n", + " - [Scaling](#Scaling)\n", + " - [Cropping](#Cropping)\n", + " - [Horizontal Flip](#Horizontal-Flip)\n", + " - [Vertical Flip](#Vertical-Flip)\n", + " - [Rotation](#Rotation)\n", + " - [Image filtering](#Image-filtering)\n", + " - [Average filter ](#Average-filter)\n", + " - [Median filter](#Median-filter)\n", + " - [Gaussian filter](#Gaussian-filter)\n", + " - [Photometric transformation](#Photometric-transformation)\n", + " - [Adjust brightness](#Adjust-brightness)\n", + " - [Adjust contrast](#Adjust-contrast)\n", + " - [Adjust saturation](#Adjust-saturation)\n", + " - [Image classifier development using CNNs](#Image-classifier-development-using-CNNs)\n", + " - [Dataset preprocessing](#Dataset-preprocessing)\n", + " - [Train, validation, and test sets](#Train,-validation,-and-test-sets)\n", + " - [Data Augmentation](#Data-Augmentation)\n", + " - [PyTorch Datasets](#PyTorch-Datasets)\n", + " - [PyTorch Dataloaders](#PyTorch-Dataloaders)\n", + " - [Model training](#Model-training)\n", + " - [Model Training Overview](#Model-Training-Overview)\n", + " - [Check which device is used for training](#Check-which-device-is-used-for-training)\n", + " - [Define training hyperparameters](#Define-training-hyperparameters)\n", + " - [Loss function](#Loss-function)\n", + " - [Initialise model architecture](#Initialise-model-architecture)\n", + " - [Optimiser function](#Optimiser-function)\n", + " - [Train model](#Train-model)\n", + " - [Learning curves](#Learning-curves)\n", + " - [Model testing](#Model-testing)\n", + " - [Explore results](#Explore-results)\n", + " - [Compute average accuracy](#Compute-average-accuracy)\n", + " - [Compute confusion matrix](#Compute-confusion-matrix)\n", + " - [Explain image classifier predictions](#Explain-image-classifier-predictions)\n", + " - [Prepare image for Grad-CAM](#Prepare-image-for-Grad-CAM)\n", + " - [Compute GradCAM heatmap](#Compute-GradCAM-heatmap)\n", + " - [Visualise Grad-CAM heatmap with the image](#Visualise-Grad-CAM-heatmap-with-the-image)" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## References\n", + "\n", + "Here are some additional references to guide you while self-learning:\n", 
+ "- Official documentation for [openCV](https://docs.opencv.org/4.x/d6/d00/tutorial_py_root.html).\n", + "- Official documentation for [PIL library](https://pillow.readthedocs.io/en/stable/).\n", + "- Official documentation for [PyTorch](https://pytorch.org/).\n", + "- Official documentation for [Albumentations](https://albumentations.ai/).\n", + "- Official documentation for [PyTorch GradCAM](https://jacobgil.github.io/pytorch-gradcam-book/introduction.html).\n", + "- [A tutorial from Microsoft to compute image classification using PyTorch](https://learn.microsoft.com/en-us/windows/ai/windows-ml/tutorials/pytorch-train-model)." + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## Libraries" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "- [Matplotlib](./20_library_matplotlib.ipynb)\n", + "- [NumPy](./21_library_numpy.ipynb)\n", + "- [scikit-learn](./22_library_sklearn.ipynb)\n", + "- OpenCV-Python\n", + "- PyTorch\n", + "- Albumentations\n", + "- PyTorch Grad-CAM" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "## Introduction\n", + "\n", + "Image Classification is a foundational task in the field of computer vision and machine learning.\n", + "This notebook aims to provide practical experience in image processing and in building and evaluating image classification models. \n", + "\n", + "It begins by demonstrating how to load and preprocess image data using Matplotlib and OpenCV-Python.\n", + "Then, it shows how to build a basic image classification pipeline based on Convolutional Neural Networks (CNNs) using PyTorch, Albumentations, and Scikit-learn.\n", + "Next, it covers how to evaluate model performance using Scikit-learn and NumPy, and finally, it introduces model explainability using Grad-CAM.\n", + "\n", + "The goal of this notebook is not to teach the underlying algorithms and procedures used in this field, but rather to give the user an idea of what can be done with these Python libraries." + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "## Classes\n", + "\n", + "The following three classes are essential for improving modularity and readability.\n", + "\n", + "- **ImageDataset** is used to load images along with their labels and to perform image augmentation.\n", + "- **ImageClassifier** is responsible for building the image classification model, which in this case is based on Convolutional Neural Networks (CNNs).\n", + "- **Trainer** handles the training and evaluation processes using batches of data.\n", + "\n", + "By organizing the code in this way, we simplify debugging and future extensions." + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "The classes are currently not complete. 
+ "\n",
+ "```ImageDataset```: \n",
+ "\n",
+ "In ```__init__```, initialise the following attributes:\n",
+ "```python\n",
+ "self.images = images # Input images\n",
+ "self.labels = labels # Output classes\n",
+ "self.transform = transform # Transformations applied to the data when calling them\n",
+ "```\n",
+ "\n",
+ "Complete function ```__len__``` - this method is needed to let the generator know how many samples there are in the data:\n",
+ "```python\n",
+ "return len(self.images)\n",
+ "```\n",
+ "\n",
+ "Complete function ```__getitem__``` - this method is needed to let the generator know what to do with each sample when it is requested:\n",
+ "```python\n",
+ "image = self.images[idx]\n",
+ "label = self.labels[idx]\n",
+ "\n",
+ "# Ensure the image is in the shape (H, W, C) for Albumentations library (library used for image augmentation)\n",
+ "image = np.transpose(image, (1, 2, 0))\n",
+ "\n",
+ "# Apply transformations on the images\n",
+ "if self.transform:\n",
+ "    augmented = self.transform(image=image)\n",
+ "    image = augmented['image']\n",
+ "\n",
+ "return image, label\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torch.utils.data import Dataset\n",
+ "import numpy as np\n",
+ "\n",
+ "class ImageDataset(Dataset):\n",
+ "    def __init__(self, images, labels, transform=None):\n",
+ "        pass\n",
+ "    \n",
+ "    def __len__(self):\n",
+ "        return\n",
+ "    \n",
+ "    def __getitem__(self, idx):\n",
+ "        return"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9",
+ "metadata": {},
+ "source": [
+ "```ImageClassifier```: \n",
+ "\n",
+ "`__init__` function:\n",
+ "\n",
+ "The first thing to do is to build the `__init__` function, which contains the variables needed for building the neural network.\n",
+ "Let's start by defining the number of feature maps in the first convolutional layer (the value is empirical):\n",
+ "\n",
+ "```python\n",
+ "self.feature_maps = 64\n",
+ "```\n",
+ "\n",
+ "To help a computer understand and classify images, we build a model made up of layers, kind of like stacking Lego blocks.\n",
+ "Each block does a specific task — detecting patterns, reducing size, or making decisions. Here's what each component does:\n",
+ "\n",
+ "```python\n",
+ "self.conv1 = nn.Conv2d(in_channels, self.feature_maps, kernel_size = 3)\n",
+ "```\n",
+ "\n",
+ "This layer scans the image for small patterns (like edges or colors).\n",
+ "`in_channels` is the number of input image channels (e.g. 3 for RGB images).\n",
+ "`self.feature_maps` is how many different patterns we want the model to learn at this layer.\n",
+ "`kernel_size = 3` means the scanning window is 3x3 pixels. The value is empirical.\n",
+ "\n",
+ "```python\n",
+ "self.pool1 = nn.MaxPool2d(kernel_size = 2)\n",
+ "```\n",
+ "This layer shrinks the size of the image while keeping the most important info (max values).\n",
+ "It helps the model focus and reduces computation.\n",
+ "\n",
+ "```python\n",
+ "self.bn1 = nn.BatchNorm2d(self.feature_maps)\n",
+ "```\n",
+ "\n",
+ "This layer normalizes the outputs, making training faster and more stable.\n",
+ "The combination of the aforementioned layers is usually called a convolutional block.\n",
+ "After defining the first convolutional block, let's define the second one:\n",
+ "\n",
+ "```python\n",
+ "self.conv2 = nn.Conv2d(self.feature_maps, self.feature_maps * 2, kernel_size = 3)\n",
+ "self.pool2 = nn.MaxPool2d(kernel_size = 2)\n",
+ "self.bn2 = nn.BatchNorm2d(self.feature_maps * 2)\n",
+ "```\n",
+ "The second block is very similar to the first block, but now it looks for more complex patterns by increasing the number of feature maps (i.e. learning more features).\n",
+ "After defining the second convolutional block, let's define the third and last one:\n",
+ "\n",
+ "```python\n",
+ "self.conv3 = nn.Conv2d(self.feature_maps * 2, self.feature_maps * 4, kernel_size = 3)\n",
+ "self.pool3 = nn.MaxPool2d(kernel_size = 2)\n",
+ "self.bn3 = nn.BatchNorm2d(self.feature_maps * 4)\n",
+ "```\n",
+ "\n",
+ "This block explores even deeper patterns, such as shapes or textures.\n",
+ "As we go deeper, the network becomes better at understanding the image.\n",
+ "Then, we define the activation layer that is going to be used in between these blocks:\n",
+ "\n",
+ "```python\n",
+ "self.relu = nn.ReLU()\n",
+ "```\n",
+ "\n",
+ "After each layer, we add a \"yes/no\" switch to keep only useful patterns.\n",
+ "ReLU (Rectified Linear Unit) sets negative values to zero — it adds non-linearity to help the network learn more complex things.\n",
+ "Next, we define the layer that transforms the data from 2D images into a 1D vector (like stretching out a grid of pixels into a line):\n",
+ "\n",
+ "```python\n",
+ "self.flatten = nn.Flatten(start_dim=1)\n",
+ "```\n",
+ "\n",
+ "Now, we define the dropout layer:\n",
+ "\n",
+ "```python\n",
+ "self.dropout = nn.Dropout(p = 0.3)\n",
+ "```\n",
+ "This layer randomly turns off a pre-defined percentage of neurons (`p = 0.3`) during training to prevent overfitting — so the model does not memorize the training data too closely.\n",
+ "Finally, we define the classifier:\n",
+ "\n",
+ "```python\n",
+ "self.out_classes = out_classes\n",
+ "self.fc = nn.Linear(1024, self.out_classes)\n",
+ "```\n",
+ "\n",
+ "This final layer is like the decision-maker.\n",
+ "It takes all the features the model has learned and decides which class (e.g. cat, dog, airplane) the input image belongs to.\n",
+ "\n",
+ "1024 is the number of features coming into the layer (it depends on the hyperparameters used in the previous layers), and `out_classes` is how many classes we want to predict.\n",
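+ "\n",
+ "As a sanity check, assuming 32x32 RGB inputs (as in the CIFAR-10 dataset used later in this notebook), we can trace the spatial size through the three blocks to see where 1024 comes from:\n",
+ "\n",
+ "```python\n",
+ "# A 3x3 convolution with no padding shrinks each spatial dimension by 2;\n",
+ "# a 2x2 max-pool then halves it (rounding down).\n",
+ "size = 32\n",
+ "for _ in range(3):\n",
+ "    size = (size - 2) // 2   # conv (kernel_size=3) followed by pool (kernel_size=2)\n",
+ "print(size)                  # 2\n",
+ "print(64 * 4 * size * size)  # 256 feature maps * 2 * 2 = 1024\n",
+ "```"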
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "10",
+ "metadata": {},
+ "source": [
+ "`forward` function:\n",
+ "\n",
+ "After defining the function `__init__`, we need to define the function `forward`.\n",
+ "It is responsible for combining all the layers defined in `__init__` to build the neural network model.\n",
+ "Basically, it describes how an input image flows through the network, one layer at a time, to become a prediction.\n",
+ "\n",
+ "```python\n",
+ "# Convolutional block 1\n",
+ "x = self.conv1(x)\n",
+ "x = self.pool1(x)\n",
+ "x = self.relu(x)\n",
+ "x = self.bn1(x)\n",
+ "\n",
+ "# Convolutional block 2\n",
+ "x = self.conv2(x)\n",
+ "x = self.pool2(x)\n",
+ "x = self.relu(x)\n",
+ "x = self.bn2(x)\n",
+ "\n",
+ "# Convolutional block 3\n",
+ "x = self.conv3(x)\n",
+ "x = self.pool3(x)\n",
+ "x = self.relu(x)\n",
+ "x = self.bn3(x)\n",
+ "\n",
+ "# Classifier\n",
+ "x = self.flatten(x)\n",
+ "x = self.dropout(x)\n",
+ "x = self.fc(x)\n",
+ "return x\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "11",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch.nn as nn\n",
+ "\n",
+ "class ImageClassifier(nn.Module):\n",
+ "    def __init__(self, in_channels = 1, out_classes = 1):\n",
+ "        super(ImageClassifier, self).__init__()\n",
+ "    \n",
+ "    def forward(self, x):\n",
+ "        return"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "12",
+ "metadata": {},
+ "source": [
+ "```Trainer```: \n",
+ "\n",
+ "`__init__` function:\n",
+ "\n",
+ "This function is used to initialise variables used in the other functions of the class.\n",
+ "Start by initialising the following attributes:\n",
+ "\n",
+ "```python\n",
+ "self.model = model\n",
+ "self.train_losses = []\n",
+ "self.val_losses = []\n",
+ "self.best_model_weights = None\n",
+ "```\n",
+ "\n",
+ "`self.model` – this is the neural network we're training.\n",
+ "`self.train_losses` and `self.val_losses` – these lists keep track of how well the model is doing on the training and validation sets over time (used to plot learning curves).\n",
+ "\n",
+ "`self.best_model_weights` – this will store a copy of the model weights from when it performed best on the validation set (used for early stopping)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "13",
+ "metadata": {},
+ "source": [
+ "`fit` function:\n",
+ "\n",
+ "This function goes through the data multiple times (epochs) to optimize the model’s performance.\n",
+ "It also applies early stopping, which stops training if performance stops improving.\n",
+ "Let's start by initialising the following variables:\n",
+ "\n",
+ "```python\n",
+ "early_stopping_count = 0\n",
+ "best_val_loss = float(\"inf\")\n",
+ "best_epoch = 0\n",
+ "```\n",
+ "\n",
+ "`early_stopping_count` is used to track the number of epochs without improving validation loss (used in early stopping).\n",
+ "`best_val_loss` is used to track the best validation loss seen so far. It is initialised to infinity so that the first epoch's validation loss always counts as an improvement.\n",
+ "`best_epoch` is used to track the epoch that got the best validation loss.\n",
+ "\n",
+ "Then, we initialise the file that is going to store the training statistics:\n",
+ "\n",
+ "```python\n",
+ "# Training log file\n",
+ "log_filename = \"training_log.txt\"\n",
+ "with open(log_filename, \"w\") as log_file:\n",
+ "    log_file.write(\"Epoch,Train Loss,Val Loss,Best Val Loss,Best Epoch\\n\")\n",
+ "```\n",
+ "\n",
+ "Now comes the training phase.\n",
+ "It includes a main for-loop that runs until the end of the pre-defined number of epochs, and two inner loops: one for optimising the model's weights and another for evaluating the model after each epoch.\n",
+ "\n",
+ "```python\n",
+ "for epoch in range(epochs):\n",
+ "    # Set the model to training mode. This is important because some layers behave differently during training than they do during evaluation.\n",
+ "    self.model.train()\n",
+ "    \n",
+ "    # Loop over the training set\n",
+ "    train_loss = 0.0\n",
+ "    train_batches_count = 0.0\n",
+ "    for i, data in enumerate(train_dataloader):\n",
+ "        # Get the data and send it to the training device\n",
+ "        inputs, labels = data\n",
+ "        inputs = inputs.to(device)\n",
+ "        labels = labels.long().to(device)\n",
+ "        \n",
+ "        # Clear old gradients\n",
+ "        optimizer.zero_grad()\n",
+ "\n",
+ "        # Perform the forward step to get the predictions for the inputs\n",
+ "        outputs = self.model(inputs)\n",
+ "\n",
+ "        # Compute the loss of the predictions\n",
+ "        loss = criterion(outputs, labels)\n",
+ "\n",
+ "        # Perform the backward step which is responsible for computing the gradients\n",
+ "        loss.backward()\n",
+ "        \n",
+ "        # Update the model weights using the new gradients\n",
+ "        optimizer.step()\n",
+ "\n",
+ "        # Accumulate the batch loss and count the number of batches\n",
+ "        train_loss += loss.item()\n",
+ "        train_batches_count += 1\n",
+ "    \n",
+ "    # Set the model to evaluation mode\n",
+ "    self.model.eval()\n",
+ "    \n",
+ "    # Loop over the validation set. Here we just want to evaluate the model. Therefore, there is no weight optimisation.\n",
+ "    val_loss = 0.0\n",
+ "    val_batches_count = 0.0\n",
+ "    for i, data in enumerate(val_dataloader):\n",
+ "        inputs, labels = data\n",
+ "        inputs = inputs.to(device)\n",
+ "        labels = labels.long().to(device)\n",
+ "        \n",
+ "        outputs = self.model(inputs)\n",
+ "        loss = criterion(outputs, labels)\n",
+ "        \n",
+ "        val_loss += loss.item()\n",
+ "        val_batches_count += 1\n",
+ "    \n",
+ "    # Divide the accumulated train and validation losses by the respective number of batches.\n",
+ "    train_loss /= train_batches_count\n",
+ "    val_loss /= val_batches_count\n",
+ "    \n",
+ "    # Average training and validation losses for the epoch are stored.\n",
+ "    self.train_losses.append(train_loss)\n",
+ "    self.val_losses.append(val_loss)\n",
+ "    \n",
+ "    # Increase early stopping count\n",
+ "    early_stopping_count += 1\n",
+ "    \n",
+ "    # In case the new validation loss is better than the best seen, \n",
+ "    # save the current epoch index, new validation loss, current model \n",
+ "    # weights and reset early stopping counter.\n",
+ "    if val_loss < best_val_loss:\n",
+ "        best_epoch = epoch\n",
+ "        best_val_loss = val_loss\n",
+ "        early_stopping_count = 0\n",
+ "        # Clone the tensors so later weight updates do not overwrite this snapshot\n",
+ "        self.best_model_weights = {k: v.detach().clone() for k, v in self.model.state_dict().items()}\n",
+ "    \n",
+ "    print(f'Epoch: {epoch}, Loss: {train_loss}, Val Loss: {val_loss}. The best val loss is {best_val_loss} in epoch {best_epoch}.')\n",
+ "    \n",
+ "    # Append the current epoch statistics to the training log file\n",
+ "    with open(log_filename, \"a\") as log_file:\n",
+ "        log_file.write(f\"{epoch},{train_loss},{val_loss},{best_val_loss},{best_epoch}\\n\")\n",
+ "    \n",
+ "    # In case the number of epochs without improving the validation loss \n",
+ "    # gets above the pre-defined threshold, stop the training early to avoid overfitting.\n",
+ "    if early_stopping_count == early_stopping_limit and early_stopping_limit > 0:\n",
+ "        break\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "14",
+ "metadata": {},
+ "source": [
+ "`predict` function:\n",
+ "\n",
+ "Once training is done, this method is used to predict labels for new data.\n",
+ "As early stopping is used during training, the final model weights might not be the best ones. Therefore, load the best-performing weights.\n",
+ "\n",
+ "```python\n",
+ "# Load best weights\n",
+ "if self.best_model_weights:\n",
+ "    self.model.load_state_dict(self.best_model_weights)\n",
+ "```\n",
+ "\n",
+ "Set the model to evaluation mode:\n",
+ "```python\n",
+ "# Test mode\n",
+ "self.model.eval()\n",
+ "```\n",
+ "\n",
+ "Loop through the test set to get the model predictions.\n",
+ "Not only the predictions, but also the original images and the true labels are stored for future use.\n",
+ "\n",
+ "```python\n",
+ "original_images = []\n",
+ "true_labels = []\n",
+ "predicted_labels = []\n",
+ "\n",
+ "for data in test_dataloader:\n",
+ "    # Load data and send it to device\n",
+ "    images, labels = data\n",
+ "    images = images.to(device)\n",
+ "\n",
+ "    # Get model predictions\n",
+ "    outputs = self.model(images)\n",
+ "\n",
+ "    # As the model outputs a vector of scores (one per class), take \n",
+ "    # the index of the maximum score, which corresponds to the predicted class.\n",
+ "    _, predicted = torch.max(outputs, 1)\n",
+ "    \n",
+ "    # .cpu() ensures that the data is on CPU and .numpy() converts it to a NumPy array\n",
+ "    images = images.cpu().numpy()\n",
+ "    labels = labels.numpy()\n",
+ "    predicted = predicted.cpu().numpy()\n",
+ "    \n",
+ "    original_images.append(images)\n",
+ "    true_labels.append(labels)\n",
+ "    predicted_labels.append(predicted)\n",
+ "\n",
+ "# Convert the list of NumPy arrays into a single NumPy array\n",
+ "original_images = np.concatenate(original_images)\n",
+ "true_labels = np.concatenate(true_labels)\n",
+ "predicted_labels = np.concatenate(predicted_labels)\n",
+ "\n",
+ "return original_images, true_labels, predicted_labels\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "15",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "class Trainer():\n",
+ "    def __init__(self, model):\n",
+ "        pass\n",
+ "    \n",
+ "    def fit(self, epochs, train_dataloader, val_dataloader, optimizer, criterion, device, early_stopping_limit = 0):\n",
+ "        return\n",
+ "    \n",
+ "    def predict(self, test_dataloader, device):\n",
+ "        return\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "16",
+ "metadata": {},
+ "source": [
+ "## Functions\n",
+ "\n",
+ "The following three functions are going to be used throughout the notebook.\n",
+ "They comprise the loading of binary files using Pickle (**load_pickle_file**), single image plotting (**plot_image**), and multiple image plotting (**plot_multiple_images**).\n",
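+ "\n",
+ "For instance, once the helpers below are defined, several images can be shown side by side by passing (image, title) tuples. A hypothetical usage sketch (`image_a` and `image_b` are placeholder names):\n",
+ "\n",
+ "```python\n",
+ "# plot_multiple_images accepts any number of (image, title) tuples\n",
+ "plot_multiple_images((image_a, \"Original\"), (image_b, \"Processed\"), figsize = (4, 5))\n",
+ "```"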
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "17",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from matplotlib import pyplot as plt\n",
+ "import pickle\n",
+ "\n",
+ "def load_pickle_file(filepath):\n",
+ "    with open(filepath, \"rb\") as f:\n",
+ "        return pickle.load(f)\n",
+ "\n",
+ "def plot_image(img, figsize = (2,3)):\n",
+ "    plt.figure(figsize = figsize)\n",
+ "    plt.imshow(img)\n",
+ "    plt.axis(\"off\")\n",
+ "    \n",
+ "def plot_multiple_images(*images_titles, figsize = (2, 3)):\n",
+ "    num_images = len(images_titles)\n",
+ "    fig, axs = plt.subplots(1, num_images, figsize = figsize)\n",
+ "    for i in range(num_images):\n",
+ "        axs[i].imshow(images_titles[i][0])\n",
+ "        axs[i].set_title(images_titles[i][1])\n",
+ "        axs[i].axis(\"off\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "18",
+ "metadata": {},
+ "source": [
+ "## Dataset\n",
+ "\n",
+ "In this section, we load the CIFAR-10 dataset, which consists of 60,000 32x32 color images across 10 different classes, with 6,000 images per class.\n",
+ "The dataset is divided into 50,000 training images and 10,000 test images. It has already been processed and is ready to use after loading the binary files *train_set.pkl* and *test_set.pkl*."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "19",
+ "metadata": {},
+ "source": [
+ "### Load data\n",
+ "\n",
+ "Training and test sets are loaded using the Pickle library. If you do not have the dataset already, open this [link](https://www.dropbox.com/scl/fo/p7gfb0kpgkbrrjup340pi/AAkX2u1g-W7290-Aq7gHHvo?rlkey=vdxaj6npfy09ywh17nl8f9v6e&st=8hfq9z20&dl=0) and download it.\n",
+ "Place it inside the data folder."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "20",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "# Sets filepaths\n",
+ "dataset_folder = os.path.join(\"data/CIFAR10\")\n",
+ "train_set_file = os.path.join(dataset_folder, \"train_set.pkl\")\n",
+ "test_set_file = os.path.join(dataset_folder, \"test_set.pkl\")\n",
+ "\n",
+ "# Load sets\n",
+ "train_set = load_pickle_file(train_set_file)\n",
+ "test_set = load_pickle_file(test_set_file)\n",
+ "\n",
+ "# CIFAR10 classes\n",
+ "CIFAR_10_CLASSES = [\n",
+ "    \"Airplane\", \"Automobile\", \"Bird\", \"Cat\", \"Deer\",\n",
+ "    \"Dog\", \"Frog\", \"Horse\", \"Ship\", \"Truck\"\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "21",
+ "metadata": {},
+ "source": [
+ "## Explore image processing\n",
+ "\n",
+ "Image processing is fundamental to computer vision, forming the basis for interpreting and analyzing visual information.\n",
+ "By applying techniques such as resizing, filtering, color adjustments, and data augmentation, image processing enhances input quality, minimizes noise, and corrects distortions.\n",
+ "These methods can also simulate real-world variability, helping models generalize better. \n",
+ "\n",
+ "In this notebook, we explore three categories of image transformations: **geometric transformations**, **image filtering**, and **photometric transformations**.\n",
+ "The following cells contain a series of exercises designed to help you explore the OpenCV-Python library.\n",
+ "If you are unfamiliar with a particular method, refer to the [Image Processing in OpenCV](https://docs.opencv.org/4.x/d2/d96/tutorial_py_table_of_contents_imgproc.html) documentation.\n",
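+ "\n",
+ "As a quick orientation, a typical OpenCV call looks like the following minimal sketch (run on a hypothetical blank image; the exercises below follow the same pattern):\n",
+ "\n",
+ "```python\n",
+ "import cv2\n",
+ "import numpy as np\n",
+ "\n",
+ "# Hypothetical 32x32 RGB image, just to show the call shape\n",
+ "img = np.zeros((32, 32, 3), dtype=np.uint8)\n",
+ "\n",
+ "# cv2.resize accepts either a target (width, height) or fx/fy scale factors\n",
+ "resized = cv2.resize(img, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_LINEAR)\n",
+ "print(resized.shape)  # (16, 16, 3)\n",
+ "```"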
+ ] + }, + { + "cell_type": "markdown", + "id": "22", + "metadata": {}, + "source": [ + "### Example image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "# Select image\n", + "image = train_set[0][9]\n", + "\n", + "# Convert image from (C, H, W) to (H, W, C)\n", + "image = np.transpose(image, (1,2,0))\n", + "\n", + "# Plot image\n", + "plot_image(image)" + ] + }, + { + "cell_type": "markdown", + "id": "24", + "metadata": {}, + "source": [ + "### Geometric transformation\n", + "\n", + "Geometric transformations alter the spatial structure of the image while preserving its semantic content.\n", + "They help the model become invariant to different orientations and scales:\n", + "\n", + "- **Scaling**: Resizes the image to a specific size, often required to match input dimensions for image classifiers.\n", + " It uses interpolation to obtain the new pixel-values.\n", + "- **Cropping**: Extracts a subregion of the image; useful for focusing on important parts or adding variability.\n", + "- **Horizontal and vertical flip**: Flips the image along the x-axis or y-axis; helps the model learn symmetry.\n", + "- **Rotation**: Rotates the image by a small angle to simulate different orientations of the objects." + ] + }, + { + "cell_type": "markdown", + "id": "25", + "metadata": {}, + "source": [ + "#### Scaling" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext tutorial.tests.testsuite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27", + "metadata": {}, + "outputs": [], + "source": [ + "%%ipytest\n", + "\n", + "import cv2\n", + "def solution_scale_image(img, scale_factor: float):\n", + " # Start your code here\n", + " return\n", + " # End your code here" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28", + "metadata": {}, + "outputs": [], + "source": [ + "# Scale image by half\n", + "scaled_image = solution_scale_image(image, 0.5)\n", + "\n", + "if scaled_image is not None:\n", + " # Use this function to plot images side by side\n", + " plot_multiple_images((image, \"Original\"), (scaled_image, \"Scaled\"), figsize = (4, 5))" + ] + }, + { + "cell_type": "markdown", + "id": "29", + "metadata": {}, + "source": [ + "#### Cropping" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30", + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext tutorial.tests.testsuite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31", + "metadata": {}, + "outputs": [], + "source": [ + "%%ipytest\n", + "\n", + "import cv2\n", + "def solution_crop_image(img, x: int, y: int, width: int, height: int):\n", + " # Start your code here\n", + " return\n", + " # End your code here" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32", + "metadata": {}, + "outputs": [], + "source": [ + "# Crop image to get a 15-by-15 image starting on (x,y): (2,2)\n", + "cropped_image = solution_crop_image(image, 2, 2, 15, 15)\n", + "\n", + "if cropped_image is not None:\n", + " # Use this function to plot images side by side\n", + " plot_multiple_images((image, \"Original\"), (cropped_image, \"Cropped\"), figsize = (4, 5))" + ] + }, + { + "cell_type": "markdown", + "id": "33", + "metadata": {}, + "source": [ + "#### Horizontal Flip" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34", + 
"metadata": {}, + "outputs": [], + "source": [ + "%reload_ext tutorial.tests.testsuite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35", + "metadata": {}, + "outputs": [], + "source": [ + "%%ipytest\n", + "\n", + "import cv2\n", + "def solution_horizontal_flip_image(img):\n", + " # Start your code here\n", + " return\n", + " # End your code here" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36", + "metadata": {}, + "outputs": [], + "source": [ + "# Flip image horizontally\n", + "flip_image_horizontal = solution_horizontal_flip_image(image)\n", + "\n", + "if flip_image_horizontal is not None:\n", + " # Use this function to plot images side by side\n", + " plot_multiple_images((image, \"Original\"), (flip_image_horizontal, \"Horizontal Flip\"), figsize = (4, 5))" + ] + }, + { + "cell_type": "markdown", + "id": "37", + "metadata": {}, + "source": [ + "#### Vertical Flip" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38", + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext tutorial.tests.testsuite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39", + "metadata": {}, + "outputs": [], + "source": [ + "%%ipytest\n", + "\n", + "import cv2\n", + "def solution_vertical_flip_image(img):\n", + " # Start your code here\n", + " return\n", + " # End your code here" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40", + "metadata": {}, + "outputs": [], + "source": [ + "# Flip image vertically\n", + "flip_image_vertical = solution_vertical_flip_image(image)\n", + "\n", + "if flip_image_vertical is not None:\n", + " # Use this function to plot images side by side\n", + " plot_multiple_images((image, \"Original\"), (flip_image_vertical, \"Vertical Flip\"), figsize = (4, 5))" + ] + }, + { + "cell_type": "markdown", + "id": "41", + "metadata": {}, + "source": [ + "#### Rotation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42", + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext tutorial.tests.testsuite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43", + "metadata": {}, + "outputs": [], + "source": [ + "%%ipytest\n", + "\n", + "import cv2\n", + "def solution_rotate_image(img, angle: float):\n", + " # Start your code here\n", + " return\n", + " # End your code here" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44", + "metadata": {}, + "outputs": [], + "source": [ + "# Rotate image by 20 degrees\n", + "rotated_image = solution_rotate_image(image, 20)\n", + "\n", + "if rotated_image is not None:\n", + " # Use this function to plot images side by side\n", + " plot_multiple_images((image, \"Original\"), (rotated_image, \"Rotated\"), figsize = (4, 5))" + ] + }, + { + "cell_type": "markdown", + "id": "45", + "metadata": {}, + "source": [ + "### Image filtering\n", + "\n", + "Filtering helps reduce noise and enhance specific image features.\n", + "These are often used as a form of preprocessing before feeding images into a model:\n", + "\n", + "- **Average filter**: Applies a smoothing effect by replacing each pixel with the average of its neighborhood.\n", + "- **Median filter**: Reduces salt-and-pepper noise by replacing each pixel with the median of neighboring pixels.\n", + "- **Gaussian filter**: Applies a Gaussian blur to smooth the image, often used to reduce high-frequency noise." 
+ ] + }, + { + "cell_type": "markdown", + "id": "46", + "metadata": {}, + "source": [ + "#### Average filter " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47", + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext tutorial.tests.testsuite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48", + "metadata": {}, + "outputs": [], + "source": [ + "%%ipytest\n", + "\n", + "import cv2\n", + "def solution_average_filter(img, kernel_size = (5, 5)):\n", + " # Start your code here\n", + " return\n", + " # End your code here" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49", + "metadata": {}, + "outputs": [], + "source": [ + "# Filter image using average filter\n", + "average_filter_image = solution_average_filter(image, (3, 3))\n", + "\n", + "if average_filter_image is not None:\n", + " # Use this function to plot images side by side\n", + " plot_multiple_images((image, \"Original\"), (average_filter_image, \"Average filter\"), figsize = (4, 5))" + ] + }, + { + "cell_type": "markdown", + "id": "50", + "metadata": {}, + "source": [ + "#### Median filter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51", + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext tutorial.tests.testsuite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52", + "metadata": {}, + "outputs": [], + "source": [ + "%%ipytest\n", + "\n", + "import cv2\n", + "def solution_median_filter(img, ksize):\n", + " # Start your code here\n", + " return\n", + " # End your code here" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53", + "metadata": {}, + "outputs": [], + "source": [ + "# Filter image using median filter\n", + "median_filter_image = solution_median_filter(image, 3)\n", + "\n", + "if median_filter_image is not None:\n", + " # Use this function to plot images side by side\n", + " plot_multiple_images((image, \"Original\"), (median_filter_image, \"Median filter\"), figsize = (4, 5))" + ] + }, + { + "cell_type": "markdown", + "id": "54", + "metadata": {}, + "source": [ + "#### Gaussian filter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55", + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext tutorial.tests.testsuite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56", + "metadata": {}, + "outputs": [], + "source": [ + "%%ipytest\n", + "\n", + "import cv2\n", + "def solution_gaussian_filter(img, kernel_size = (5, 5), sigma = 0):\n", + " # Start your code here\n", + " return\n", + " # End your code here" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57", + "metadata": {}, + "outputs": [], + "source": [ + "# Filter image using Gaussian filter\n", + "gaussian_filter_image = solution_gaussian_filter(image, (7, 7), 0)\n", + "\n", + "if gaussian_filter_image is not None:\n", + " # Use this function to plot images side by side\n", + " plot_multiple_images((image, \"Original\"), (gaussian_filter_image, \"Gaussian filter\"), figsize = (4, 5))" + ] + }, + { + "cell_type": "markdown", + "id": "58", + "metadata": {}, + "source": [ + "### Photometric transformation\n", + "\n", + "Photometric transformations modify the color properties of an image to simulate different lighting conditions and improve model robustness to brightness and contrast changes:\n", + "\n", + "- **Brightness**: Randomly increases or decreases the brightness of the image.\n", + "- **Contrast**: Alters the 
difference between light and dark regions in the image.\n", + "- **Saturation**: Modifies the intensity of the colors in the image." + ] + }, + { + "cell_type": "markdown", + "id": "59", + "metadata": {}, + "source": [ + "#### Adjust brightness" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60", + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext tutorial.tests.testsuite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61", + "metadata": {}, + "outputs": [], + "source": [ + "%%ipytest\n", + "\n", + "import cv2\n", + "def solution_adjust_brightness(img, brightness_value):\n", + " # Start your code here\n", + " return\n", + " # End your code here" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62", + "metadata": {}, + "outputs": [], + "source": [ + "# Brighter image (positive brightness value)\n", + "brighter_image = solution_adjust_brightness(image, 100)\n", + "\n", + "# Darker image (negative brightness value)\n", + "darker_image = solution_adjust_brightness(image, -100)\n", + "\n", + "if brighter_image is not None and darker_image is not None:\n", + " # Use this function to plot images side by side\n", + " plot_multiple_images((image, \"Original\"), (brighter_image, \"Brighter image\"), (darker_image, \"Darker image\"), figsize = (7, 8))" + ] + }, + { + "cell_type": "markdown", + "id": "63", + "metadata": {}, + "source": [ + "#### Adjust contrast" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64", + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext tutorial.tests.testsuite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65", + "metadata": {}, + "outputs": [], + "source": [ + "%%ipytest\n", + "\n", + "import cv2\n", + "def solution_adjust_contrast(img, contrast_value):\n", + " # Start your code here\n", + " return\n", + " # End your code here" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66", + "metadata": {}, + "outputs": [], + "source": [ + "# Increase contrast (Value > 1.0)\n", + "high_contrast_image = solution_adjust_contrast(image, 2.0)\n", + "\n", + "# Reduce contrast (Value < 1.0)\n", + "low_contrast_image = solution_adjust_contrast(image, 0.5)\n", + "\n", + "if high_contrast_image is not None and low_contrast_image is not None:\n", + " # Use this function to plot images side by side\n", + " plot_multiple_images((image, \"Original\"), (high_contrast_image, \"High contrast image\"), (low_contrast_image, \"Low contrast image\"), figsize = (7, 8))" + ] + }, + { + "cell_type": "markdown", + "id": "67", + "metadata": {}, + "source": [ + "#### Adjust saturation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68", + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext tutorial.tests.testsuite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69", + "metadata": {}, + "outputs": [], + "source": [ + "%%ipytest\n", + "\n", + "import cv2\n", + "def solution_adjust_saturation(img, saturation_factor):\n", + " # Start your code here\n", + " return\n", + " # End your code here" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70", + "metadata": {}, + "outputs": [], + "source": [ + "# Decrease saturation\n", + "low_saturation_image = solution_adjust_saturation(image, 0.2)\n", + "\n", + "# Increase saturation\n", + "high_saturation_image = solution_adjust_saturation(image, 2.5)\n", + "\n", + "if low_saturation_image is not None and 
high_saturation_image is not None:\n",
+ "    # Use this function to plot images side by side\n",
+ "    plot_multiple_images((image, \"Original\"), (low_saturation_image, \"Low saturation image\"), (high_saturation_image, \"High saturation image\"), figsize = (7, 8))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "71",
+ "metadata": {},
+ "source": [
+ "## Image classifier development using CNNs\n",
+ "\n",
+ "Image classification is the task of assigning a label or category to an input image from a predefined set of classes.\n",
+ "It is a fundamental problem in computer vision with widespread applications, including facial recognition, medical imaging, quality control, and autonomous driving. \n",
+ "This section outlines the key steps involved in developing an image classification model using PyTorch:\n",
+ "\n",
+ "- It begins with data preprocessing, which includes splitting the dataset into training, validation, and test sets. \n",
+ "- Afterwards, it defines data augmentation strategies using the Albumentations library, loads the data as PyTorch datasets, and initialises PyTorch dataloaders to efficiently feed data during training. \n",
+ "- The next step is model training, where a CNN-based model is initialized and optimised using the training and validation data. \n",
+ "- After training, the model is evaluated on the test set to assess its performance.\n",
+ "  The evaluation includes metrics such as accuracy and the confusion matrix, which help interpret the model's predictive behavior. \n",
+ "- Finally, the PyTorch Grad-CAM library is used to visualize the regions of input images that contribute most to the model’s decisions, providing insights into model explainability using representative examples."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "72",
+ "metadata": {},
+ "source": [
+ "### Dataset preprocessing"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "73",
+ "metadata": {},
+ "source": [
+ "#### Train, validation, and test sets\n",
+ "\n",
+ "```train_test_split``` from Scikit-learn can be used to split the original training set into training and validation sets.\n",
+ "The test set is already defined by the dataset's authors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "74",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.model_selection import train_test_split\n",
+ "\n",
+ "# Train and validation sets\n",
+ "X_train, y_train = train_set[0], train_set[1]\n",
+ "X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size = 0.3, random_state = 42)\n",
+ "\n",
+ "# Test set\n",
+ "X_test, y_test = test_set[0], test_set[1]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "75",
+ "metadata": {},
+ "source": [
+ "#### Data Augmentation\n",
+ "\n",
+ "Data augmentation is a crucial technique in image classification that helps improve the performance and robustness of machine learning models.\n",
+ "It involves generating new training samples by applying random transformations — such as rotation, flipping, cropping, scaling, or color jittering — to the original images.\n",
\n", + "\n", + "Albumentations is one of the most widely used libraries for performing data augmentation in image classification tasks.\n", + "It includes augmentation techniques that replicate operations commonly used in image processing, such as:\n", + "\n", + "- ```A.Affine``` for scaling;\n", + "\n", + "- ```A.Rotate``` for rotation;\n", + "\n", + "- ```A.HorizontalFlip``` for horizontal flipping;\n", + "\n", + "- ```A.VerticalFlip``` for vertical flipping;\n", + "\n", + "- ```A.ColorJitter``` for color jittering.\n", + "\n", + "Albumentations can also be used for image normalization (```A.Normalize```), resizing (```A.Resize```), and converting images to PyTorch tensors with the (Channel, Height, Width) format using ```A.ToTensorV2```, which is required for model training.\n", + "Apply the following transformations only to the training set, as the validation set should remain as close as possible to the test set. Therefore, no transformations should be applied to it.\n", + "\n", + "```python\n", + "A.Affine(scale = (0.2, 1.5), p = 0.1),\n", + "A.Rotate(limit = 45, p = 0.1),\n", + "A.HorizontalFlip(p = 0.1),\n", + "A.VerticalFlip(p = 0.1),\n", + "A.ColorJitter(brightness = (0.5, 1.5), contrast = (0.5, 1.5), saturation = (0.5, 1.5), hue = (0,0), p = 0.1)\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "76", + "metadata": {}, + "outputs": [], + "source": [ + "import albumentations as A\n", + "\n", + "# Transformations performed on train set\n", + "TARGET_SIZE = 32\n", + "train_transform = A.Compose([\n", + " A.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)),\n", + " A.Resize(height = TARGET_SIZE, width = TARGET_SIZE),\n", + " A.ToTensorV2()\n", + "])\n", + "\n", + "# Transformations performed on validation and test sets\n", + "val_transform = A.Compose([\n", + " A.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)),\n", + " A.Resize(height = TARGET_SIZE, width = TARGET_SIZE),\n", + " A.ToTensorV2()\n", + "])" + ] + }, + { + "cell_type": "markdown", + "id": "77", + "metadata": {}, + "source": [ + "#### PyTorch Datasets\n", + "\n", + "```ImageDataset``` class is based on PyTorch ```Dataset``` class and is used for loading the images and their corresponding labels, for applying transformations (such as data augmentation), and returns them in a format suitable for model training, validating, and testing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "78", + "metadata": {}, + "outputs": [], + "source": [ + "# Dataset classes necessary for the data loaders\n", + "train_dataset = ImageDataset(X_train, y_train, transform = train_transform)\n", + "val_dataset = ImageDataset(X_val, y_val, transform = val_transform)\n", + "test_dataset = ImageDataset(X_test, y_test, transform = val_transform)" + ] + }, + { + "cell_type": "markdown", + "id": "79", + "metadata": {}, + "source": [ + "#### PyTorch Dataloaders\n", + "\n", + "```DataLoader``` is essential for training efficiency and performance.\n", + "It abstracts the complexity of batching, shuffling, and parallel data access, allowing you to focus on building and training your models.\n", + "```batch_size``` specifies the number of samples processed in parallel during each training iteration.\n", + "It is typically treated as a hyperparameter, as its optimal value depends on hardware constraints (e.g., GPU memory) and its interaction with training dynamics.\n", + "Notably, it is often linearly related with the learning rate. 
+ "Larger batch sizes generally require proportionally larger learning rates to maintain stable and efficient convergence.\n",
+ "```shuffle``` controls whether the dataset is randomly permuted at the start of each epoch.\n",
+ "Enabling ```shuffle = True``` for the training set is typically beneficial, as it helps prevent the model from learning misleading patterns due to class-wise ordering in the dataset, which could hinder generalization and convergence.\n",
+ "For the validation and test sets, shuffling brings no benefit, so they can keep ```shuffle = False``` to make evaluation runs reproducible."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "80",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torch.utils.data import DataLoader\n",
+ "\n",
+ "# Data loaders needed for the model training\n",
+ "BATCH_SIZE = 64\n",
+ "train_dataloader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)\n",
+ "val_dataloader = DataLoader(val_dataset, batch_size = BATCH_SIZE, shuffle = False)\n",
+ "test_dataloader = DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle = False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "81",
+ "metadata": {},
+ "source": [
+ "### Model training\n",
+ "\n",
+ "Model training comprises a series of steps, each described in detail in the overview below:\n",
+ "\n",
+ "1. Check which devices are available for training.\n",
+ "   If a GPU with CUDA cores is available, it should be used, as it greatly speeds up training; otherwise, fall back to the CPU.\n",
+ "1. Define the model and training hyperparameters, such as the number of output classes, the number of training epochs, the early-stopping patience (the number of consecutive non-improving epochs allowed before training stops), and the learning rate.\n",
+ "   Other hyperparameters can be defined, depending on what the user wants to do during training.\n",
+ "1. Define the loss function that is going to be used to evaluate the model and send it to the training device.\n",
+ "1. Instantiate the model using the ```ImageClassifier``` class and send it to the training device.\n",
+ "1. Define the optimiser function.\n",
+ "1. Train the model, unless optimised weights are already available, and explore the learning curves."
+ ] + }, + { + "cell_type": "markdown", + "id": "82", + "metadata": {}, + "source": [ + "### Model Training Overview\n", + "\n", + "Model training involves a sequence of key steps.\n", + "The first step is to check which computational devices are available.\n", + "If a GPU with CUDA cores is accessible, it should be used, as it significantly accelerates training (```DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")```).\n", + "Otherwise, the model will be trained on the CPU. Next, we define the model and training hyperparameters.\n", + "These typically include:\n", + "\n", + "- The number of output classes (```NUMBER_CLASSES = len(CIFAR_10_CLASSES)```);\n", + "- The number of training epochs (i.e., how many times the model sees the full training set) (```EPOCHS = 500```);\n", + "- The patience for early stopping (i.e., how many consecutive epochs without improvement are allowed before stopping training) (```EARLY_STOPPING_LIMIT = EPOCHS // 10```);\n", + "- The learning rate (```LR = 0.001```).\n", + "\n", + "Additional hyperparameters may also be configured depending on the training strategy or specific use case.\n", + "\n", + "In this notebook, we focus on setting the number of training epochs. We also discuss **early stopping**, a regularization technique used to prevent overfitting.\n", + "During training, the model's performance is evaluated on a validation set at the end of each epoch.\n", + "If the validation loss does not improve after a predefined number of epochs, training is stopped, and the model reverts to the best-performing checkpoint (see [Early Stopping](https://paperswithcode.com/method/early-stopping)).\n", + "The **learning rate** controls how quickly the model updates its weights during training.\n", + "If it's too high, the model may overshoot optimal loss values, leading to instability.\n", + "If it's too low, the model may converge very slowly or get stuck in a local minimum.\n", + "Although learning rate tuning is not performed in this notebook, it is an essential hyperparameter that should be carefully selected (see [What is learning rate in machine learning?](https://www.ibm.com/think/topics/learning-rate)).\n", + "\n", + "After setting the hyperparameters, we define the **loss function** used to evaluate model performance (```criterion = nn.CrossEntropyLoss()```).\n", + "Both the model and loss function should be moved to the selected training device (```criterion = criterion.to(DEVICE)```).\n", + "The model is then instantiated using the `ImageClassifier` class (```model = ImageClassifier(in_channels = 3, out_classes = NUMBER_CLASSES)```) and transferred to the training device (```model = model.to(DEVICE)```).\n", + "The **optimizer** is also defined and configured on the same device (```optimizer = optim.Adam(model.parameters(), lr = LR)```).\n", + "\n", + "Finally, if no pre-trained weights are available, the training process begins, and we monitor the learning curves to assess the model’s performance over time." 
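+ ,"\n\n",
+ "Putting the snippets above together, the setup could look like the following sketch (values as assumed above; `CIFAR_10_CLASSES` and `ImageClassifier` were defined earlier in the notebook):\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "\n",
+ "# Training device\n",
+ "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "# Hyperparameters\n",
+ "NUMBER_CLASSES = len(CIFAR_10_CLASSES)\n",
+ "EPOCHS = 500\n",
+ "EARLY_STOPPING_LIMIT = EPOCHS // 10\n",
+ "LR = 0.001\n",
+ "\n",
+ "# Loss, model, and optimiser\n",
+ "criterion = nn.CrossEntropyLoss().to(DEVICE)\n",
+ "model = ImageClassifier(in_channels = 3, out_classes = NUMBER_CLASSES).to(DEVICE)\n",
+ "optimizer = optim.Adam(model.parameters(), lr = LR)\n",
+ "```\n",
+ "\n",
+ "The stub cells below ask you to fill in these pieces step by step."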
+ ] + }, + { + "cell_type": "markdown", + "id": "83", + "metadata": {}, + "source": [ + "#### Check which device is used for training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "# Check which device is available for training the model\n" + ] + }, + { + "cell_type": "markdown", + "id": "85", + "metadata": {}, + "source": [ + "#### Define training hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86", + "metadata": {}, + "outputs": [], + "source": [ + "# Get number of output classes\n", + "NUMBER_CLASSES = len(CIFAR_10_CLASSES)\n", + "\n", + "# Set the number of training epochs\n", + "\n", + "# Set the number of consecutive not improving epochs needed for stopping the training\n", + "\n", + "# Set the learning rate" + ] + }, + { + "cell_type": "markdown", + "id": "87", + "metadata": {}, + "source": [ + "#### Loss function\n", + "\n", + "The cross entropy loss function is defined by:\n", + "\n", + "$$\n", + "\\mathcal{L} = -\\sum_{i=1}^{C} y_i \\log(\\hat{y}_i)\n", + "$$\n", + "\n", + "Where:\n", + "\n", + "$\\mathcal{L}$: Cross-entropy loss\n", + "\n", + "$C$: Total number of classes\n", + "\n", + "$y_i$: Ground truth indicator for class $i$, where $y_i = 1$ if class $i$ is the correct class, otherwise $y_i = 0$\n", + "\n", + "$\\hat{y}_i$: Predicted probability for class $i$, typically from the softmax output, where $0 \\leq \\hat{y}_i \\leq 1$ and $\\sum_{i=1}^{C} \\hat{y}_i = 1$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88", + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "\n", + "# Initialise the Cross Entropy Loss and send it to the training device\n" + ] + }, + { + "cell_type": "markdown", + "id": "89", + "metadata": {}, + "source": [ + "#### Initialise model architecture" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialise image classifier and send it to the training device" + ] + }, + { + "cell_type": "markdown", + "id": "91", + "metadata": {}, + "source": [ + "#### Optimiser function\n", + "\n", + "In this notebook, we are using Adam optimiser (```optimizer = optim.Adam(model.parameters(), lr = LR)```) which is one of the most used optimisers in deep neural network optimisation (see [Gentle Introduction to the Adam Optimisation Algorithm for Deep Learning](https://machinelearningmastery.com/adam-optimization-algorithm-for-deep-learning/)).\n", + "\n", + "The parameter update at each step is given by:\n", + "\n", + "$$\n", + "\\begin{aligned}\n", + "m_t &= \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\\n", + "v_t &= \\beta_2 v_{t-1} + (1 - \\beta_2) g_t^2 \\\\\n", + "\\hat{m}_t &= \\frac{m_t}{1 - \\beta_1^t} \\\\\n", + "\\hat{v}_t &= \\frac{v_t}{1 - \\beta_2^t} \\\\\n", + "\\theta_t &= \\theta_{t-1} - \\alpha \\frac{\\hat{m}_t}{\\sqrt{\\hat{v}_t} + \\epsilon}\n", + "\\end{aligned}\n", + "$$\n", + "\n", + "Where:\n", + "\n", + "$\\theta_t$: Parameters at time step $t$\n", + "\n", + "$g_t$: Gradient of the loss with respect to parameters at step $t$\n", + "\n", + "$m_t$: Exponentially decaying average of past gradients (1st moment)\n", + "\n", + "$v_t$: Exponentially decaying average of past squared gradients (2nd moment)\n", + "\n", + "$\\hat{m}_t$, $\\hat{v}_t$: Bias-corrected estimates of $m_t$ and $v_t$\n", + "\n", + "$\\alpha$: Learning rate\n", + "\n", + "$\\beta_1$: Decay rate for 
the first moment estimate (typically 0.9)\n",
+ "\n",
+ "$\\beta_2$: Decay rate for the second moment estimate (typically 0.999)\n",
+ "\n",
+ "$\\epsilon$: Small constant to prevent division by zero (e.g., 1e-8)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "92",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch.optim as optim\n",
+ "\n",
+ "# Initialise the Adam optimiser\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "93",
+ "metadata": {},
+ "source": [
+ "#### Train model\n",
+ "\n",
+ "Here, we train the model.\n",
+ "First, we initialise the class ```Trainer```, which we are going to use for training and evaluating the model using the PyTorch ```DataLoader```s defined before (```trainer = Trainer(model)```).\n",
+ "In case some model weights are already available, we can skip the training and use them."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "94",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialise the Trainer instance, which is going to be used to train the image classifier"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "95",
+ "metadata": {},
+ "source": [
+ "After initialising the trainer instance, check whether a trained model already exists.\n",
+ "If so, load the weights using ```model_weights = torch.load(model_path, weights_only=True)```.\n",
+ "Then, load the weights into the model using ```model.load_state_dict(model_weights)```.\n",
+ "Finally, set the model to evaluation mode (```model.eval()```).\n",
+ "This step is essential because certain layers, such as batch normalization and dropout, behave differently during training and evaluation.\n",
+ "Setting the model to evaluation mode ensures they operate correctly during validation or testing. \n",
+ "\n",
+ "If no pre-trained model is available, train a new model using the training and validation sets along with the predefined hyperparameters (```trainer.fit(EPOCHS, train_dataloader, val_dataloader, optimizer, criterion, DEVICE, EARLY_STOPPING_LIMIT)```).\n",
+ "After training, save the best model weights (```torch.save(trainer.best_model_weights, model_path)```)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "96",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "\n",
+ "# Model filename\n",
+ "model_path = \"cnn_weights.pt\"\n",
+ "\n",
+ "if os.path.exists(model_path):\n",
+ "    pass\n",
+ "else:\n",
+ "    pass"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "97",
+ "metadata": {},
+ "source": [
+ "#### Learning curves\n",
+ "\n",
+ "After training the model, we can analyse the learning curves to assess the training process.\n",
+ "These curves, which typically display the loss over epochs for both the training and validation sets, are crucial for improving model performance.\n",
+ "They can help identify issues like overfitting or underfitting.\n",
+ "Overfitting occurs when the model performs well on the training data but poorly on the validation data, usually indicated by a widening gap between the two curves.\n",
+ "Underfitting, on the other hand, is suggested when both the training and validation curves show poor performance and fail to improve.\n",
+ "By monitoring these curves, we can adjust hyperparameters or modify the model architecture to address such issues. \n",
+ "\n",
+ "First, load the log file using ```pandas``` (```training_log = pd.read_csv(\"training_log.txt\")```).\n",
+ "Then, use the ```matplotlib``` library to plot the learning curves."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "98", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from matplotlib import pyplot as plt\n", + "\n", + "# Load the training log file\n", + "training_log = None\n", + "\n", + "plt.figure()\n", + "plt.plot(training_log[\"Train Loss\"])\n", + "plt.plot(training_log[\"Val Loss\"])\n", + "plt.legend([\"Train loss\", \"Val loss\"])\n", + "plt.xlabel(\"Epochs\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.title(\"Learning curve of image classification model\")" + ] + }, + { + "cell_type": "markdown", + "id": "99", + "metadata": {}, + "source": [ + "### Model testing\n", + "\n", + "Once the model has been trained and optimised, we can evaluate its performance using the ```Trainer``` class by calling the ```predict``` method with the test dataloader and the device:\n", + "\n", + "```python\n", + "original_images, true_labels, predicted_labels = trainer.predict(test_dataloader, DEVICE)\n", + "```\n", + "\n", + "This method returns three NumPy arrays:\n", + "\n", + "- ```original_images```: the input images from the test set\n", + "- ```true_labels```: the corresponding ground truth labels\n", + "- ```predicted_labels```: the model's predicted classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "100", + "metadata": {}, + "outputs": [], + "source": [ + "# Write here the line of code to predict the labels for the test set" + ] + }, + { + "cell_type": "markdown", + "id": "101", + "metadata": {}, + "source": [ + "### Explore results\n", + "\n", + "To evaluate the results, we display the model's accuracy along with the confusion matrix.\n", + "The confusion matrix is a powerful evaluation tool that helps us understand the model’s performance across multiple classes.\n", + "It maps the relationship between true and predicted labels, showing the number of instances for each pair of true and predicted classes." + ] + }, + { + "cell_type": "markdown", + "id": "102", + "metadata": {}, + "source": [ + "#### Compute average accuracy\n", + "\n", + "To compute the accuracy, get the number of test samples (```num_test_samples = len(original_images)```), check how many samples were correctly classified (```correct = (true_labels == predicted_labels).sum()```), and get the ratio (```accuracy = correct / num_test_samples```).\n",
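+ "\n", + "As a self-contained illustration on made-up labels (independent of the notebook's variables):\n", + "\n", + "```python\n", + "import numpy as np\n", + "\n", + "true_demo = np.array([0, 1, 2, 2, 1])\n", + "pred_demo = np.array([0, 2, 2, 2, 1])\n", + "\n", + "correct_demo = (true_demo == pred_demo).sum()\n", + "accuracy_demo = correct_demo / len(true_demo)\n", + "print(\"Accuracy:\", accuracy_demo)  # 4 of 5 correct -> 0.8\n", + "```"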
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "103", + "metadata": {}, + "outputs": [], + "source": [ + "# Compute average accuracy\n", + "accuracy = 0.0\n", + "print(\"Accuracy:\", accuracy)" + ] + }, + { + "cell_type": "markdown", + "id": "104", + "metadata": {}, + "source": [ + "#### Compute confusion matrix\n", + "\n", + "To compute the confusion matrix, use ```confusion_matrix``` from the scikit-learn library:\n", + "\n", + "```python\n", + "cm = confusion_matrix(true_labels, predicted_labels)\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "105", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.metrics import confusion_matrix\n", + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", + "\n", + "# Compute confusion matrix\n", + "cm = None\n", + "\n", + "# Plot confusion matrix\n", + "fig, ax = plt.subplots(figsize=(10, 8))\n", + "cax = ax.matshow(cm, cmap='Greens')\n", + "\n", + "# Add labels, title, and ticks\n", + "ax.set_xticks(np.arange(NUMBER_CLASSES))\n", + "ax.set_yticks(np.arange(NUMBER_CLASSES))\n", + "ax.set_xticklabels(CIFAR_10_CLASSES)\n", + "ax.set_yticklabels(CIFAR_10_CLASSES)\n", + "plt.xlabel('Predicted Labels')\n", + "plt.ylabel('True Labels')\n", + "plt.title('Confusion Matrix for test set of CIFAR10')\n", + "\n", + "# Annotate each cell with the numeric value\n", + "for (i, j), val in np.ndenumerate(cm):\n", + "    ax.text(j, i, f'{val}', ha='center', va='center', color='black')\n", + "\n", + "# Rotate class names on x-axis\n", + "plt.xticks(rotation=45)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "106", + "metadata": {}, + "source": [ + "### Explain image classifier predictions\n", + "\n", + "Deep neural networks are often described as \"black boxes\" because their decision-making processes are difficult to understand and interpret.\n", + "To address this, researchers have developed various methods to make these models more explainable.\n", + "One such method is Grad-CAM (Gradient-weighted Class Activation Mapping).\n", + "Grad-CAM computes the gradients of a target class score with respect to the feature maps of the final convolutional layer and generates a heatmap that highlights the regions of the input image most influential in the model’s prediction for that class." + ] + }, + { + "cell_type": "markdown", + "id": "107", + "metadata": {}, + "source": [ + "#### Prepare image for Grad-CAM\n", + "\n", + "To prepare the image for Grad-CAM visualisation:\n", + "\n", + "- First, convert it to (Height, Width, Channels) format using ```img_np = np.transpose(img, (1, 2, 0))  # shape: (H, W, C)```, and normalise its values to the [0, 1] range with ```img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())```.\n", + "  This processed image is used only for visualisation, as expected by the PyTorch-GradCAM library.\n", + "- Next, modify the original image for model inference by adding a batch dimension: ```img = np.expand_dims(img, axis=0)```, then convert it to a PyTorch tensor: ```img = torch.from_numpy(img)```, and move it to the appropriate computation device using ```img = img.to(DEVICE)```.\n", + "- Finally, retrieve the predicted and true labels, as both are required for computing and visualising the Grad-CAM output.\n",
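+ "\n", + "A self-contained sketch of these preparation steps on a random dummy image (the notebook instead uses an image from ```original_images``` and moves the tensor to the ```DEVICE``` defined earlier):\n", + "\n", + "```python\n", + "import numpy as np\n", + "import torch\n", + "\n", + "img = np.random.rand(3, 32, 32).astype(np.float32)  # dummy (C, H, W) image\n", + "\n", + "# Visualisation copy: (H, W, C), values scaled to [0, 1]\n", + "img_np = np.transpose(img, (1, 2, 0))\n", + "img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())\n", + "\n", + "# Inference copy: add a batch dimension and convert to a tensor\n", + "img_batch = np.expand_dims(img, axis=0)  # shape: (1, C, H, W)\n", + "img_tensor = torch.from_numpy(img_batch)  # in the notebook, also .to(DEVICE)\n", + "```"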
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "108", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "\n", + "# Get a single test image along with its predicted and true labels\n", + "idx = 1\n", + "img = original_images[idx]\n", + "pred_label = predicted_labels[idx]\n", + "true_label = true_labels[idx]" + ] + }, + { + "cell_type": "markdown", + "id": "109", + "metadata": {}, + "source": [ + "#### Compute GradCAM heatmap\n", + "\n", + "To compute the Grad-CAM heatmap:\n", + "\n", + "- First, ensure that the `requires_grad` attribute of the input image tensor is set to `True` by using `img.requires_grad = True`.\n", + "  This enables gradient computation with respect to the image, which is necessary for generating class activation maps.\n", + "- Next, specify the layer to inspect using ```target_layers = [model.conv3]```.\n", + "  Typically, the last convolutional layer of the image classifier is chosen because it preserves spatial information, which is crucial for identifying the regions of the input image that most strongly influence the model's prediction.\n", + "- Then, define the target class to be explained with ```targets = [ClassifierOutputTarget(pred_label)]```, where ```pred_label``` is the class index corresponding to the model’s predicted output (or any other class of interest).\n", + "- Finally, compute the Grad-CAM heatmap using the activations and gradients from the selected layer:\n", + "```python\n", + "# Create CAM object\n", + "with GradCAM(model=model, target_layers=target_layers) as cam:\n", + "    grad_cam_matrix = cam(input_tensor=img, targets=targets)\n", + "    grad_cam_matrix = grad_cam_matrix[0, :]\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "110", + "metadata": {}, + "outputs": [], + "source": [ + "from pytorch_grad_cam import GradCAM\n", + "from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget\n", + "\n", + "# Make sure the input requires grad\n", + "\n", + "# Define the layer(s) to inspect\n", + "\n", + "# Define the target class you want to explain\n", + "\n", + "# Compute the Grad-CAM heatmap" + ] + }, + { + "cell_type": "markdown", + "id": "111", + "metadata": {}, + "source": [ + "#### Visualise Grad-CAM heatmap with the image\n", + "\n", + "After obtaining the Grad-CAM heatmap, we overlay it on the input image to visualise the regions that contributed most to the model’s prediction (```visualisation = show_cam_on_image(img_np, grad_cam_matrix, use_rgb=True)```).\n", + "This helps identify which pixels the model focused on when predicting the class.\n",
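+ "\n", + "Conceptually, ```show_cam_on_image``` maps the heatmap through a colormap and blends it with the image. A rough, self-contained approximation (a sketch, not the library's actual implementation):\n", + "\n", + "```python\n", + "import numpy as np\n", + "from matplotlib import cm\n", + "\n", + "def overlay_heatmap(image, heatmap, alpha=0.5):\n", + "    \"\"\"Blend a [0, 1] heatmap onto a [0, 1] RGB image.\"\"\"\n", + "    coloured = cm.jet(heatmap)[..., :3]  # heatmap -> RGB via a colormap\n", + "    blended = alpha * coloured + (1 - alpha) * image\n", + "    return np.clip(blended, 0.0, 1.0)\n", + "\n", + "# Demo on random data\n", + "overlay = overlay_heatmap(np.random.rand(32, 32, 3), np.random.rand(32, 32))\n", + "```"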
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "112", + "metadata": {}, + "outputs": [], + "source": [ + "from pytorch_grad_cam.utils.image import show_cam_on_image\n", + "\n", + "# Combine CAM with image\n", + "visualisation = None\n", + "\n", + "# Plot image with GradCAM output\n", + "true_class = CIFAR_10_CLASSES[true_label]\n", + "pred_class = CIFAR_10_CLASSES[pred_label]\n", + "plot_multiple_images((img_np, f\"Original - {true_class}\"), (visualisation, f\"Grad-CAM - {pred_class}\"), figsize=(5, 6))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorial/tests/test_31_image_classification.py b/tutorial/tests/test_31_image_classification.py new file mode 100644 index 00000000..d5e1eccc --- /dev/null +++ b/tutorial/tests/test_31_image_classification.py @@ -0,0 +1,175 @@ +import cv2 +import numpy as np +import pytest + + +def reference_scale_image(image, scale_factor): + # Get the current dimensions + height, width = image.shape[:2] + + # Calculate the new dimensions + new_width = int(width * scale_factor) + new_height = int(height * scale_factor) + new_size = (new_width, new_height) + + # Resize the image + return cv2.resize(image, new_size) + + +@pytest.mark.parametrize("scale_factor", [0.5, 1.0, 2.0]) +def test_scale_image(scale_factor, function_to_test): + image = np.ones((32, 32, 3), dtype=np.uint8) * 255 + image_test = function_to_test(image, scale_factor) + image_reference = reference_scale_image(image, scale_factor) + assert image_test.shape == image_reference.shape + + +def reference_crop_image(image, x: int, y: int, width: int, height: int): + x1, x2, y1, y2 = x, x + width, y, y + height + return image[y1:y2, x1:x2] + + +@pytest.mark.parametrize( + "x, y, width, height", [(2, 2, 2, 2), (5, 5, 4, 4), (10, 10, 6, 6)] +) +def test_crop_image(x, y, width, height, function_to_test): + image = np.ones((32, 32, 3), dtype=np.uint8) * 255 + image_test = function_to_test(image, x, y, width, height) + image_reference = reference_crop_image(image, x, y, width, height) + assert image_test.shape == image_reference.shape + + +def reference_horizontal_flip_image(image): + return cv2.flip(image, 1) + + +def test_horizontal_flip_image(function_to_test): + image = np.ones((32, 32, 3), dtype=np.uint8) * 255 + image_test = function_to_test(image) + image_reference = reference_horizontal_flip_image(image) + assert np.allclose(image_test, image_reference) + + +def reference_vertical_flip_image(image): + return cv2.flip(image, 0) + + +def test_vertical_flip_image(function_to_test): + image = np.ones((32, 32, 3), dtype=np.uint8) * 255 + image_test = function_to_test(image) + image_reference = reference_vertical_flip_image(image) + assert np.allclose(image_test, image_reference) + + +def reference_rotate_image(image, angle: float): + (h, w) = image.shape[:2] + center = (w // 2, h // 2) + mat = cv2.getRotationMatrix2D(center, angle, scale=1.0) + + # Compute new bounding dimensions + cos = np.abs(mat[0, 0]) + sin = np.abs(mat[0, 1]) + new_w = int((h * sin) + (w * cos)) + new_h = int((h * cos) + (w * sin)) + + # Adjust rotation matrix for translation + mat[0, 2] += (new_w / 2) - 
center[0] + mat[1, 2] += (new_h / 2) - center[1] + + # Perform rotation with expanded canvas + return cv2.warpAffine(image, mat, (new_w, new_h)) + + +@pytest.mark.parametrize("angle", [5, 10, 20, 30]) +def test_rotate_image(angle, function_to_test): + image = np.ones((32, 32, 3), dtype=np.uint8) * 255 + image_test = function_to_test(image, angle) + image_reference = reference_rotate_image(image, angle) + assert np.allclose(image_test, image_reference) + + +def reference_average_filter(image, kernel_size): + return cv2.blur(image, kernel_size) + + +@pytest.mark.parametrize("kernel_size", [(3, 3), (5, 5)]) +def test_average_filter(kernel_size, function_to_test): + image = np.ones((32, 32, 3), dtype=np.uint8) * 255 + image_test = function_to_test(image, kernel_size) + image_reference = reference_average_filter(image, kernel_size) + assert np.allclose(image_test, image_reference) + + +def reference_median_filter(image, ksize): + return cv2.medianBlur(image, ksize) + + +@pytest.mark.parametrize("ksize", [3, 5]) +def test_median_filter(ksize, function_to_test): + image = np.ones((32, 32, 3), dtype=np.uint8) * 255 + image_test = function_to_test(image, ksize) + image_reference = reference_median_filter(image, ksize) + assert np.allclose(image_test, image_reference) + + +def reference_gaussian_filter(image, kernel_size, sigma): + return cv2.GaussianBlur(image, kernel_size, sigma) + + +@pytest.mark.parametrize( + "kernel_size, sigma", [((3, 3), 0), ((5, 5), 0), ((3, 3), 1), ((5, 5), 1)] +) +def test_gaussian_filter(kernel_size, sigma, function_to_test): + image = np.ones((32, 32, 3), dtype=np.uint8) * 255 + image_test = function_to_test(image, kernel_size, sigma) + image_reference = reference_gaussian_filter(image, kernel_size, sigma) + assert np.allclose(image_test, image_reference) + + +def reference_adjust_brightness(image, brightness_value): + return cv2.convertScaleAbs(image, beta=brightness_value) + + +@pytest.mark.parametrize("brightness_value", [-30, -20, -10, 0, 10, 20, 30]) +def test_adjust_brightness(brightness_value, function_to_test): + image = np.ones((32, 32, 3), dtype=np.uint8) * 255 + image_test = function_to_test(image, brightness_value) + image_reference = reference_adjust_brightness(image, brightness_value) + assert np.allclose(image_test, image_reference) + + +def reference_adjust_contrast(image, contrast_value): + return cv2.convertScaleAbs(image, alpha=contrast_value) + + +@pytest.mark.parametrize("contrast_value", [0.5, 1.0, 1.5, 2.0]) +def test_adjust_contrast(contrast_value, function_to_test): + image = np.ones((32, 32, 3), dtype=np.uint8) * 255 + image_test = function_to_test(image, contrast_value) + image_reference = reference_adjust_contrast(image, contrast_value) + assert np.allclose(image_test, image_reference) + + +def reference_adjust_saturation(image, saturation_factor): + # Convert the image from RGB to HSV + image_hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) + + # Split the HSV image into Hue, Saturation, and Value channels + hue, saturation, value = cv2.split(image_hsv) + + # Adjust the saturation channel (ensure it stays within the valid range) + saturation = np.clip(saturation * saturation_factor, 0, 255) + + # Merge the channels back + image_hsv_adjusted = cv2.merge([hue, saturation.astype(np.uint8), value]) + + # Convert the adjusted image back to RGB + return cv2.cvtColor(image_hsv_adjusted, cv2.COLOR_HSV2RGB) + + +@pytest.mark.parametrize("saturation_factor", [0.5, 1.0, 1.5, 2.0]) +def test_adjust_saturation(saturation_factor, function_to_test): + image = 
np.ones((32, 32, 3), dtype=np.uint8) * 255 + image_test = function_to_test(image, saturation_factor) + image_reference = reference_adjust_saturation(image, saturation_factor) + assert np.allclose(image_test, image_reference)