From 09b02cdec14afcd6effe0cd42915240fce03ab65 Mon Sep 17 00:00:00 2001 From: Matt McCormick Date: Mon, 23 Feb 2026 17:47:46 -0500 Subject: [PATCH] ENH: Add ngff-zarr multi-resolution registration example Add Example 23 demonstrating a workflow that combines ngff-zarr multi-resolution image pyramids with ITKElastix registration: - Convert ITK images to ngff-zarr Multiscales via to_multiscales - Register at a coarse resolution (rigid + affine + bspline) - Convert Elastix results to itk.CompositeTransform - Apply the transform at full resolution in parallel using dask.array.map_blocks with itk.resample_image_filter Uses itk_image_to_ngff_image and ngff_image_to_itk_image to bridge between ITK and ngff-zarr data representations. --- ...ple23_NgffZarrMultiscaleRegistration.ipynb | 437 ++++++++++++++++++ 1 file changed, 437 insertions(+) create mode 100644 examples/ITK_Example23_NgffZarrMultiscaleRegistration.ipynb diff --git a/examples/ITK_Example23_NgffZarrMultiscaleRegistration.ipynb b/examples/ITK_Example23_NgffZarrMultiscaleRegistration.ipynb new file mode 100644 index 0000000..4d52843 --- /dev/null +++ b/examples/ITK_Example23_NgffZarrMultiscaleRegistration.ipynb @@ -0,0 +1,437 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Multi-Resolution Registration with ngff-zarr and ITKElastix\n", + "\n", + "This notebook demonstrates a workflow that combines [ngff-zarr](https://github.com/fideus-labs/ngff-zarr) multi-resolution image pyramids with [ITKElastix](https://github.com/InsightSoftwareConsortium/ITKElastix) registration.\n", + "\n", + "The approach:\n", + "\n", + "1. Load two images and convert them to `ngff_zarr.Multiscales` pyramids with `to_multiscales`.\n", + "2. Register the images at a **coarse resolution** for speed (rigid \u2192 affine \u2192 B-spline).\n", + "3. Convert the Elastix result to a standard `itk.CompositeTransform`.\n", + "4. Apply the transform to resample the **full-resolution** image in parallel using `dask.array.map_blocks`.\n", + "\n", + "The `ngff_zarr` functions `itk_image_to_ngff_image` and `ngff_image_to_itk_image` bridge between ITK and ngff-zarr data representations, preserving spatial metadata (spacing, origin)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import itk\n", + "import numpy as np\n", + "import dask.array as da\n", + "import matplotlib.pyplot as plt\n", + "from ngff_zarr import (\n", + " itk_image_to_ngff_image,\n", + " ngff_image_to_itk_image,\n", + " to_multiscales,\n", + ")\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Images\n", + "\n", + "We load the fixed and moving 2D CT head images as ITK float images, which are required by Elastix." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fixed_image = itk.imread(\"data/CT_2D_head_fixed.mha\", itk.F)\n", + "moving_image = itk.imread(\"data/CT_2D_head_moving.mha\", itk.F)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axs = plt.subplots(1, 2, figsize=[12, 6])\n", + "axs[0].imshow(fixed_image, cmap=\"gray\")\n", + "axs[0].set_title(\"Fixed\")\n", + "axs[1].imshow(moving_image, cmap=\"gray\")\n", + "axs[1].set_title(\"Moving\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate Multi-Resolution Pyramids with ngff-zarr\n", + "\n", + "`itk_image_to_ngff_image` converts an `itk.Image` to an `NgffImage`, preserving spacing and origin as `scale` and `translation` metadata. The image data is backed by a dask array.\n", + "\n", + "`to_multiscales` then generates a multi-resolution pyramid. Here we use `scale_factors=[2, 4]`, which produces three levels: the original resolution, 2\u00d7 downsampled, and 4\u00d7 downsampled. We will register at the coarsest level for speed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ngff_fixed = itk_image_to_ngff_image(fixed_image)\n", + "ngff_moving = itk_image_to_ngff_image(moving_image)\n", + "\n", + "scale_factors = [2, 4]\n", + "multiscales_fixed = to_multiscales(ngff_fixed, scale_factors=scale_factors)\n", + "multiscales_moving = to_multiscales(ngff_moving, scale_factors=scale_factors)\n", + "\n", + "print(\"Fixed image pyramid:\")\n", + "for i, image in enumerate(multiscales_fixed.images):\n", + " print(f\" Level {i}: shape={image.data.shape}, scale={image.scale}\")\n", + "\n", + "print(\"\\nMoving image pyramid:\")\n", + "for i, image in enumerate(multiscales_moving.images):\n", + " print(f\" Level {i}: shape={image.data.shape}, scale={image.scale}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Register at Coarse Resolution\n", + "\n", + "We extract the coarsest resolution level from both pyramids and convert back to `itk.Image` using `ngff_image_to_itk_image`. We pass `wasm=False` to get native `itk.Image` objects, which are required by ITKElastix.\n", + "\n", + "We then run a three-stage Elastix registration (rigid \u2192 affine \u2192 B-spline). Registering at lower resolution is significantly faster. Because the resulting transform is defined in physical coordinates, it can be applied at any resolution level." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Extract the coarsest resolution level\n", + "fixed_coarse_ngff = multiscales_fixed.images[-1]\n", + "moving_coarse_ngff = multiscales_moving.images[-1]\n", + "\n", + "print(f\"Coarse fixed: shape={fixed_coarse_ngff.data.shape}, scale={fixed_coarse_ngff.scale}\")\n", + "print(f\"Coarse moving: shape={moving_coarse_ngff.data.shape}, scale={moving_coarse_ngff.scale}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Convert NgffImage back to itk.Image for Elastix\n", + "fixed_coarse = ngff_image_to_itk_image(fixed_coarse_ngff, wasm=False)\n", + "moving_coarse = ngff_image_to_itk_image(moving_coarse_ngff, wasm=False)\n", + "\n", + "# Ensure float32 pixel type as required by Elastix\n", + "ImageType = itk.Image[itk.F, 2]\n", + "if not isinstance(fixed_coarse, ImageType):\n", + " fixed_coarse = itk.cast_image_filter(fixed_coarse, ttype=(type(fixed_coarse), ImageType))\n", + "if not isinstance(moving_coarse, ImageType):\n", + " moving_coarse = itk.cast_image_filter(moving_coarse, ttype=(type(moving_coarse), ImageType))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set up multi-stage registration: rigid -> affine -> bspline\n", + "parameter_object = itk.ParameterObject.New()\n", + "parameter_object.AddParameterMap(\n", + " itk.ParameterObject.GetDefaultParameterMap(\"rigid\")\n", + ")\n", + "parameter_object.AddParameterMap(\n", + " itk.ParameterObject.GetDefaultParameterMap(\"affine\")\n", + ")\n", + "parameter_object.AddParameterMap(\n", + " itk.ParameterObject.GetDefaultParameterMap(\"bspline\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run registration using the object-oriented API.\n", + "# We need the registration_method object to later extract the ITK transform.\n", + "registration_method = itk.ElastixRegistrationMethod[\n", + " type(fixed_coarse), type(moving_coarse)\n", + "].New(\n", + " fixed_image=fixed_coarse,\n", + " moving_image=moving_coarse,\n", + " parameter_object=parameter_object,\n", + ")\n", + "registration_method.SetLogToConsole(False)\n", + "registration_method.Update()\n", + "\n", + "result_image_coarse = registration_method.GetOutput()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize the coarse registration result\n", + "fig, axs = plt.subplots(1, 3, figsize=[15, 5])\n", + "axs[0].imshow(fixed_coarse, cmap=\"gray\")\n", + "axs[0].set_title(\"Fixed (Coarse)\")\n", + "axs[1].imshow(moving_coarse, cmap=\"gray\")\n", + "axs[1].set_title(\"Moving (Coarse)\")\n", + "axs[2].imshow(result_image_coarse, cmap=\"gray\")\n", + "axs[2].set_title(\"Registered (Coarse)\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Convert to ITK Transform\n", + "\n", + "The `ConvertToItkTransform` method converts the Elastix registration results into a standard `itk.CompositeTransform`. This composite transform chains the rigid, affine, and B-spline stages into a single object that can be used with any ITK resampling filter.\n", + "\n", + "Because the transform is defined in **physical coordinates** (not pixel indices), it applies correctly at any resolution level." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "elx_advanced_transform = registration_method.GetCombinationTransform()\n", + "itk_composite_transform = itk.CompositeTransform[itk.D, 2].cast(\n", + " registration_method.ConvertToItkTransform(elx_advanced_transform)\n", + ")\n", + "print(itk_composite_transform)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Apply Transform at Full Resolution with dask.array.map_blocks\n", + "\n", + "Now we apply the transform obtained at coarse resolution to resample the **full-resolution** moving image. We use `dask.array.map_blocks` to process each chunk of the output image independently and in parallel.\n", + "\n", + "The strategy:\n", + "- The full-resolution fixed image's dask array serves as the **output grid template**. Each chunk defines a spatial region in the output.\n", + "- For each chunk, we use `block_info` to compute its physical origin from the array location and the image's spacing and origin metadata.\n", + "- We create a small ITK reference image for that chunk's spatial region and call `itk.resample_image_filter` to resample the moving image into it.\n", + "\n", + "For simplicity, the full moving image is loaded into memory so that each block function can sample from it. More memory-efficient approaches (e.g., lazy region-based reads) will be available in future ngff-zarr releases." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get full-resolution NgffImage data (dask-backed)\n", + "fixed_full = multiscales_fixed.images[0]\n", + "moving_full = multiscales_moving.images[0]\n", + "\n", + "# Materialize the full-resolution moving image as an itk.Image for resampling.\n", + "# Each block function will sample from this image.\n", + "moving_itk_full = ngff_image_to_itk_image(moving_full, wasm=False)\n", + "if not isinstance(moving_itk_full, ImageType):\n", + " moving_itk_full = itk.cast_image_filter(\n", + " moving_itk_full, ttype=(type(moving_itk_full), ImageType)\n", + " )\n", + "\n", + "print(f\"Full-resolution fixed: shape={fixed_full.data.shape}\")\n", + "print(f\"Full-resolution moving: shape={moving_full.data.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def resample_block(block, block_info=None, transform=None, moving_image=None,\n", + " spacing_itk=None, origin_itk=None):\n", + " \"\"\"Resample a single output block from the moving image using the given transform.\n", + "\n", + " Parameters\n", + " ----------\n", + " block : numpy array\n", + " The chunk from the fixed image grid (used only for its shape).\n", + " block_info : dict\n", + " Provided by dask.array.map_blocks. Contains chunk position information.\n", + " transform : itk.Transform\n", + " The spatial transform mapping fixed to moving image space.\n", + " moving_image : itk.Image\n", + " The full-resolution moving image to resample from.\n", + " spacing_itk : list\n", + " Pixel spacing in ITK order [x, y].\n", + " origin_itk : list\n", + " Image origin in ITK order [x, y].\n", + " \"\"\"\n", + " if block_info is None:\n", + " return block\n", + "\n", + " # block_info[0]['array-location'] gives [(y_start, y_stop), (x_start, x_stop)]\n", + " # for a 2D array with dims (y, x)\n", + " array_location = block_info[0][\"array-location\"]\n", + "\n", + " # Compute the physical origin for this block.\n", + " # array-location is in (y, x) order; ITK origin is in (x, y) order.\n", + " block_origin_itk = [\n", + " origin_itk[0] + array_location[1][0] * spacing_itk[0], # x\n", + " origin_itk[1] + array_location[0][0] * spacing_itk[1], # y\n", + " ]\n", + "\n", + " block_shape = block.shape # (rows, cols) = (y_size, x_size)\n", + "\n", + " # Create a reference image for this block's spatial region\n", + " reference = itk.Image[itk.F, 2].New()\n", + " region = itk.ImageRegion[2]()\n", + " region.SetSize([int(block_shape[1]), int(block_shape[0])]) # ITK size is [x, y]\n", + " region.SetIndex([0, 0])\n", + " reference.SetRegions(region)\n", + " reference.SetSpacing(spacing_itk)\n", + " reference.SetOrigin(block_origin_itk)\n", + " reference.Allocate()\n", + "\n", + " # Resample the moving image into this block's region\n", + " resampled = itk.resample_image_filter(\n", + " moving_image,\n", + " transform=transform,\n", + " use_reference_image=True,\n", + " reference_image=reference,\n", + " )\n", + "\n", + " return itk.array_from_image(resampled)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Extract spacing and origin from the full-resolution fixed image.\n", + "# NgffImage scale/translation use dim-name keys in (y, x) order.\n", + "# ITK uses (x, y) order.\n", + "spacing_itk = [fixed_full.scale[\"x\"], fixed_full.scale[\"y\"]]\n", + "origin_itk = [fixed_full.translation[\"x\"], fixed_full.translation[\"y\"]]\n", + "\n", + "# Apply the transform in parallel across all chunks\n", + "resampled_dask = da.map_blocks(\n", + " resample_block,\n", + " fixed_full.data,\n", + " dtype=np.float32,\n", + " transform=itk_composite_transform,\n", + " moving_image=moving_itk_full,\n", + " spacing_itk=spacing_itk,\n", + " origin_itk=origin_itk,\n", + ")\n", + "\n", + "print(f\"Resampled dask array: shape={resampled_dask.shape}, chunks={resampled_dask.chunks}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compute the result (triggers parallel execution across chunks)\n", + "resampled_array = resampled_dask.compute()\n", + "print(f\"Resampled result shape: {resampled_array.shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Results\n", + "\n", + "Compare the full-resolution resampled result with the fixed and moving images." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fixed_array = np.asarray(fixed_full.data)\n", + "\n", + "fig, axs = plt.subplots(1, 4, figsize=[20, 5])\n", + "axs[0].imshow(fixed_array, cmap=\"gray\")\n", + "axs[0].set_title(\"Fixed (Full Resolution)\")\n", + "axs[1].imshow(np.asarray(moving_full.data), cmap=\"gray\")\n", + "axs[1].set_title(\"Moving (Full Resolution)\")\n", + "axs[2].imshow(resampled_array, cmap=\"gray\")\n", + "axs[2].set_title(\"Resampled Moving\")\n", + "diff = fixed_array - resampled_array\n", + "im = axs[3].imshow(diff, cmap=\"RdBu\")\n", + "axs[3].set_title(\"Difference\")\n", + "fig.colorbar(im, ax=axs[3], orientation=\"vertical\", fraction=0.046, pad=0.04)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This example demonstrated a multi-resolution registration and resampling workflow:\n", + "\n", + "- **`ngff_zarr.itk_image_to_ngff_image`** and **`ngff_zarr.ngff_image_to_itk_image`** bridge between ITK and ngff-zarr data representations, preserving spatial metadata.\n", + "- **`ngff_zarr.to_multiscales`** generates multi-resolution pyramids from a single image with one call.\n", + "- **Coarse-resolution registration** with ITKElastix (rigid \u2192 affine \u2192 B-spline) is fast, and because the resulting transform is defined in physical coordinates, it applies at any resolution.\n", + "- **`dask.array.map_blocks`** enables parallel resampling at full resolution, processing each output chunk independently with `itk.resample_image_filter`." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}