diff --git a/.github/workflows/build-notebooks.yaml b/.github/workflows/build-notebooks.yaml new file mode 100644 index 0000000..08ba001 --- /dev/null +++ b/.github/workflows/build-notebooks.yaml @@ -0,0 +1,45 @@ +name: Build Notebooks +on: + push: + +jobs: + run: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: Install dependencies + run: | + python -m pip install -U pip + python -m pip install jupytext nbconvert + + - name: Build notebooks + run: | + jupyter nbconvert 01_CARE/solution.ipynb \ + --ClearOutputPreprocessor.enabled=True --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags solution --to notebook \ + --output exercise.ipynb + jupyter nbconvert 02_Noise2Void/solution.ipynb \ + --ClearOutputPreprocessor.enabled=True --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags solution --to notebook \ + --output exercise.ipynb + jupyter nbconvert 03_COSDD/solution.ipynb \ + --ClearOutputPreprocessor.enabled=True --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags solution --to notebook \ + --output exercise.ipynb + jupyter nbconvert 03_COSDD/bonus-solution-generation.ipynb \ + --ClearOutputPreprocessor.enabled=True --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags solution --to notebook \ + --output bonus-exercise.ipynb + jupyter nbconvert 04_DenoiSplit/solution.ipynb \ + --ClearOutputPreprocessor.enabled=True --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags solution --to notebook \ + --output exercise.ipynb + jupyter nbconvert 05_bonus_Noise2Noise/n2n_solution.ipynb \ + --ClearOutputPreprocessor.enabled=True --TagRemovePreprocessor.enabled=True --TagRemovePreprocessor.remove_cell_tags solution --to notebook \ + --output n2n_exercise.ipynb + + - uses: EndBug/add-and-commit@v9 + with: + add: "*.ipynb" diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..20ae35f --- /dev/null +++ b/.gitignore @@ -0,0 +1,177 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +*.ckpt +01_CARE/runs/ +02_Noise2Void/checkpoints/ +02_Noise2Void/logs/ +03_COSDD/data/ +03_COSDD/checkpoints/ +03_COSDD/COSDD +04_bonus_denoiSplit/tensorboard_logs/ +04_bonus_Noise2Noise/logs/ +04_DenoiSplit/denoisplit/ +04_DenoiSplit/tensorboard_logs/ +04_DenoiSplit/lightning_logs/ +05_bonus_Noise2Noise/logs/ +data/ diff --git a/01_CARE/exercise.ipynb b/01_CARE/exercise.ipynb new file mode 100644 index 0000000..026750e --- /dev/null +++ b/01_CARE/exercise.ipynb @@ -0,0 +1,1057 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Content-aware image restoration\n", + "\n", + "Fluorescence microscopy is constrained by the microscope's optics, fluorophore chemistry, and the sample's photon tolerance. These constraints require balancing imaging speed, resolution, light exposure, and depth. CARE demonstrates how Deep learning can extend the range of biological phenomena observable by microscopy when any of these factor becomes limiting.\n", + "\n", + "**Reference**: Weigert, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097. doi:[10.1038/s41592-018-0216-7](https://www.nature.com/articles/s41592-018-0216-7)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CARE\n", + "\n", + "In this first exercise we will train a CARE model for a 2D denoising task. CARE stands for Content-Aware image REstoration, and is a supervised method in which we use pairs of degraded and high quality image to train a particular task. 
The original paper demonstrated improved image quality on a variety of restoration tasks, such as denoising and resolution improvement. Here, we will apply CARE to denoise images acquired at low laser power in order to recover the biological structures present in the data!\n", + "\n", + "

\n", + " \"Denoising \n", + "

\n", + "\n", + "We'll use the UNet model that we built in the semantic segmentation exercise and use a different set of functions to train the model for restoration rather than segmentation.\n", + "\n", + "\n", + "

Objectives

\n", + " \n", + "- Train a UNet on a new task!\n", + "- Understand how to train CARE\n", + " \n", + "
\n", + "\n", + "\n", + "\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + " Set your python kernel to 05_image_restoration\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import tifffile\n", + "import numpy as np\n", + "from pathlib import Path\n", + "from typing import Union, List, Tuple\n", + "from torch.utils.data import Dataset, DataLoader\n", + "import torch.nn\n", + "import torch.optim\n", + "from torch import no_grad, cuda\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "from datetime import datetime\n", + "from dlmbl_unet import UNet\n", + "\n", + "%matplotlib inline" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "
\n", + "\n", + "## Part 1: Set-up the data\n", + "\n", + "CARE is a fully supervised algorithm, therefore we need image pairs for training. In practice this is best achieved by acquiring each image twice, once with short exposure time or low laser power to obtain a noisy low-SNR (signal-to-noise ratio) image, and once with high SNR.\n", + "\n", + "Here, we will be using high SNR images of Human U2OS cells taken from the Broad Bioimage Benchmark Collection ([BBBC006v1](https://bbbc.broadinstitute.org/BBBC006)). The low SNR images were created by synthetically adding strong read-out and shot noise, and applying pixel binning of 2x2, thus mimicking acquisitions at a very low light level.\n", + "\n", + "Since the image pairs were synthetically created in this example, they are already aligned perfectly. Note that when working with real paired acquisitions, the low and high SNR images are not pixel-perfect aligned so they would often need to be co-registered before training a CARE model." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Split the dataset into training and validation\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the paths\n", + "root_path = Path(\"./../data\")\n", + "root_path = root_path / \"denoising-CARE_U2OS.unzip\" / \"data\" / \"U2OS\"\n", + "assert root_path.exists(), f\"Path {root_path} does not exist\"\n", + "\n", + "train_images_path = root_path / \"train\" / \"low\"\n", + "train_targets_path = root_path / \"train\" / \"GT\"\n", + "test_image_path = root_path / \"test\" / \"low\"\n", + "test_target_path = root_path / \"test\" / \"GT\"\n", + "\n", + "\n", + "image_files = list(Path(train_images_path).rglob(\"*.tif\"))\n", + "target_files = list(Path(train_targets_path).rglob(\"*.tif\"))\n", + "assert len(image_files) == len(\n", + " target_files\n", + "), \"Number of images and targets do not match\"\n", + "\n", + "print(f\"Total size of train dataset: {len(image_files)}\")\n", + "\n", + "# Split the train data into train and validation\n", + "seed = 42\n", + "train_files_percentage = 0.8\n", + "np.random.seed(seed)\n", + "shuffled_indices = np.random.permutation(len(image_files))\n", + "image_files = np.array(image_files)[shuffled_indices]\n", + "target_files = np.array(target_files)[shuffled_indices]\n", + "assert all(\n", + " [i.name == j.name for i, j in zip(image_files, target_files)]\n", + "), \"Files do not match\"\n", + "\n", + "train_image_files = image_files[: int(train_files_percentage * len(image_files))]\n", + "train_target_files = target_files[: int(train_files_percentage * len(target_files))]\n", + "val_image_files = image_files[int(train_files_percentage * len(image_files)) :]\n", + "val_target_files = target_files[int(train_files_percentage * len(target_files)) :]\n", + "assert all(\n", + " [i.name == j.name for i, j in zip(train_image_files, train_target_files)]\n", + "), \"Train files do not match\"\n", + "assert all(\n", + " [i.name == j.name for i, j in zip(val_image_files, val_target_files)]\n", + "), \"Val files do not match\"\n", + "\n", + "print(f\"Train dataset size: {len(train_image_files)}\")\n", + "print(f\"Validation dataset size: {len(val_image_files)}\")\n", + "\n", + "# Read the test files\n", + "test_image_files = list(test_image_path.rglob(\"*.tif\"))\n", + "test_target_files = list(test_target_path.rglob(\"*.tif\"))\n", + "print(f\"Number of test files: {len(test_image_files)}\")" + ] + }, + { + 
"attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Patching function\n", + "\n", + "In the majority of cases microscopy images are too large to be processed at once and need to be divided into smaller patches. We will define a function that takes image and target arrays and extract random (paired) patches from them.\n", + "\n", + "The method is a bit scary because accessing the whole patch coordinates requires some magical python expressions. \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_patches(\n", + " image_array: np.ndarray,\n", + " target_array: np.ndarray,\n", + " patch_size: Union[List[int], Tuple[int, ...]],\n", + ") -> Tuple[np.ndarray, np.ndarray]:\n", + " \"\"\"\n", + " Create random patches from an array and a target.\n", + "\n", + " The method calculates how many patches the image can be divided into and then\n", + " extracts an equal number of random patches.\n", + "\n", + " Important: the images should have an extra dimension before the spatial dimensions.\n", + " if you try it with only 2D or 3D images, don't forget to add an extra dimension\n", + " using `image = image[np.newaxis, ...]`\n", + " \"\"\"\n", + " # random generator\n", + " rng = np.random.default_rng()\n", + " image_patches = []\n", + " target_patches = []\n", + "\n", + " # iterate over the number of samples in the input array\n", + " for s in range(image_array.shape[0]):\n", + " # calculate the number of patches we can extract\n", + " sample = image_array[s]\n", + " target_sample = target_array[s]\n", + " n_patches = np.ceil(np.prod(sample.shape) / np.prod(patch_size)).astype(int)\n", + "\n", + " # iterate over the number of patches\n", + " for _ in range(n_patches):\n", + " # get random coordinates for the patch and create the crop coordinates\n", + " crop_coords = [\n", + " rng.integers(0, sample.shape[i] - patch_size[i], endpoint=True)\n", + " for i in range(len(patch_size))\n", + " ]\n", + "\n", + " # extract patch from the data\n", + " patch = (\n", + " sample[\n", + " (\n", + " ...,\n", + " *[\n", + " slice(c, c + patch_size[i])\n", + " for i, c in enumerate(crop_coords)\n", + " ],\n", + " )\n", + " ]\n", + " .copy()\n", + " .astype(np.float32)\n", + " )\n", + "\n", + " # same for the target patch\n", + " target_patch = (\n", + " target_sample[\n", + " (\n", + " ...,\n", + " *[\n", + " slice(c, c + patch_size[i])\n", + " for i, c in enumerate(crop_coords)\n", + " ],\n", + " )\n", + " ]\n", + " .copy()\n", + " .astype(np.float32)\n", + " )\n", + "\n", + " # add the patch pair to the list\n", + " image_patches.append(patch)\n", + " target_patches.append(target_patch)\n", + "\n", + " # return stack of patches\n", + " return np.stack(image_patches), np.stack(target_patches)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create patches\n", + "\n", + "To train the network, we will use patches of size 128x128. We first need to load the data, stack it and then call our patching function." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load images and stack them into arrays\n", + "train_images_array = np.stack([tifffile.imread(str(f)) for f in train_image_files])\n", + "train_targets_array = np.stack([tifffile.imread(str(f)) for f in train_target_files])\n", + "val_images_array = np.stack([tifffile.imread(str(f)) for f in val_image_files])\n", + "val_targets_array = np.stack([tifffile.imread(str(f)) for f in val_target_files])\n", + "\n", + "test_images_array = np.stack([tifffile.imread(str(f)) for f in test_image_files])\n", + "test_targets_array = np.stack([tifffile.imread(str(f)) for f in test_target_files])\n", + "\n", + "\n", + "print(f\"Train images array shape: {train_images_array.shape}\")\n", + "print(f\"Validation images array shape: {val_images_array.shape}\")\n", + "print(f\"Test array shape: {test_images_array.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create patches\n", + "patch_size = (128, 128)\n", + "\n", + "train_images_patches, train_targets_patches = create_patches(\n", + " train_images_array, train_targets_array, patch_size\n", + ")\n", + "assert (\n", + " train_images_patches.shape[0] == train_targets_patches.shape[0]\n", + "), \"Number of patches do not match\"\n", + "\n", + "val_images_patches, val_targets_patches = create_patches(\n", + " val_images_array, val_targets_array, patch_size\n", + ")\n", + "assert (\n", + " val_images_patches.shape[0] == val_targets_patches.shape[0]\n", + "), \"Number of patches do not match\"\n", + "\n", + "print(f\"Train images patches shape: {train_images_patches.shape}\")\n", + "print(f\"Validation images patches shape: {val_images_patches.shape}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize training patches" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(3, 2, figsize=(15, 15))\n", + "ax[0, 0].imshow(train_images_patches[0], cmap=\"magma\")\n", + "ax[0, 0].set_title(\"Train image\")\n", + "ax[0, 1].imshow(train_targets_patches[0], cmap=\"magma\")\n", + "ax[0, 1].set_title(\"Train target\")\n", + "ax[1, 0].imshow(train_images_patches[1], cmap=\"magma\")\n", + "ax[1, 0].set_title(\"Train image\")\n", + "ax[1, 1].imshow(train_targets_patches[1], cmap=\"magma\")\n", + "ax[1, 1].set_title(\"Train target\")\n", + "ax[2, 0].imshow(train_images_patches[2], cmap=\"magma\")\n", + "ax[2, 0].set_title(\"Train image\")\n", + "ax[2, 1].imshow(train_targets_patches[2], cmap=\"magma\")\n", + "ax[2, 1].set_title(\"Train target\")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "### Dataset class\n", + "\n", + "In modern deep learning libraries, the data is often wrapped into a class called a `Dataset`. Instances of that class are then used to extract the patches before feeding them to the network.\n", + "\n", + "Here, the class will be wrapped around our pre-computed stacks of patches. Our `CAREDataset` class is built on top of the PyTorch `Dataset` class (we say it \"inherits\" from `Dataset`, the \"parent\" class). 
That means that it has some functions hidden from us that are defined in the PyTorch repository, but that we also need to implement specific pre-defined methods, such as `__len__` and `__getitem__`. The advantage is that PyTorch knows what to do with a `Dataset` \"child\" class, since its behaviour is defined in the `Dataset` class, but we can do operations that are closely related to our own data in the methods we implement." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
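If the inheritance mechanics feel abstract, here is a minimal toy `Dataset` (an illustration only, not part of the exercise) showing the two methods PyTorch expects us to provide:

```python
from torch.utils.data import Dataset

class ToySquares(Dataset):
    """Tiny example Dataset returning (index, index**2) pairs."""
    def __init__(self, n_items: int):
        self.n_items = n_items

    def __len__(self):             # used by len(dataset) and by the DataLoader
        return self.n_items

    def __getitem__(self, index):  # called with an integer index
        return index, index ** 2

toy = ToySquares(5)
print(len(toy), toy[3])  # 5 (3, 9)
```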

Question: Normalization

\n", + "\n", + "In the following cell we calculate the mean and standard deviation of the input and target images so that we can normalize them.\n", + "Why is normalization important? \n", + "Should we normalize the input and ground truth data the same way? \n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + } + }, + "outputs": [], + "source": [ + "# Calculate the mean and std of the train dataset\n", + "train_mean = train_images_array.mean()\n", + "train_std = train_images_array.std()\n", + "target_mean = train_targets_array.mean()\n", + "target_std = train_targets_array.std()\n", + "print(f\"Train mean: {train_mean}, std: {train_std}\")\n", + "print(f\"Target mean: {target_mean}, std: {target_std}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "These functions will be used to normalize the data and perform data augmentation as it is loaded." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def normalize(\n", + " image: np.ndarray,\n", + " mean: float = 0.0,\n", + " std: float = 1.0,\n", + ") -> np.ndarray:\n", + " \"\"\"\n", + " Normalize an image with given mean and standard deviation.\n", + "\n", + " Parameters\n", + " ----------\n", + " image : np.ndarray\n", + " Array containing single image or patch, 2D or 3D.\n", + " mean : float, optional\n", + " Mean value for normalization, by default 0.0.\n", + " std : float, optional\n", + " Standard deviation value for normalization, by default 1.0.\n", + "\n", + " Returns\n", + " -------\n", + " np.ndarray\n", + " Normalized array.\n", + " \"\"\"\n", + " return (image - mean) / std\n", + "\n", + "\n", + "def _flip_and_rotate(\n", + " image: np.ndarray, rotate_state: int, flip_state: int\n", + ") -> np.ndarray:\n", + " \"\"\"\n", + " Apply the given number of 90 degrees rotations and flip to an array.\n", + "\n", + " Parameters\n", + " ----------\n", + " image : np.ndarray\n", + " Array containing single image or patch, 2D or 3D.\n", + " rotate_state : int\n", + " Number of 90 degree rotations to apply.\n", + " flip_state : int\n", + " 0 or 1, whether to flip the array or not.\n", + "\n", + " Returns\n", + " -------\n", + " np.ndarray\n", + " Flipped and rotated array.\n", + " \"\"\"\n", + " rotated = np.rot90(image, k=rotate_state, axes=(-2, -1))\n", + " flipped = np.flip(rotated, axis=-1) if flip_state == 1 else rotated\n", + " return flipped.copy()\n", + "\n", + "\n", + "def augment_batch(\n", + " patch: np.ndarray,\n", + " target: np.ndarray,\n", + " seed: int = 42,\n", + ") -> Tuple[np.ndarray, ...]:\n", + " \"\"\"\n", + " Apply augmentation function to patches and masks.\n", + "\n", + " Parameters\n", + " ----------\n", + " patch : np.ndarray\n", + " Array containing single image or patch, 2D or 3D with masked pixels.\n", + " original_image : np.ndarray\n", + " Array containing original image or patch, 2D or 3D.\n", + " mask : np.ndarray\n", + " Array containing only masked pixels, 2D or 3D.\n", + " seed : int, optional\n", + " Seed for random number generator, controls the rotation and flipping.\n", + "\n", + " Returns\n", + " -------\n", + " Tuple[np.ndarray, ...]\n", + " Tuple of augmented arrays.\n", + " \"\"\"\n", + " rng = np.random.default_rng(seed=seed)\n", + " rotate_state = rng.integers(0, 4)\n", + " flip_state = rng.integers(0, 2)\n", + " return (\n", + " _flip_and_rotate(patch, rotate_state, flip_state),\n", + " _flip_and_rotate(target, rotate_state, flip_state),\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Defining the Dataset\n", + "\n", + "Here we're defining the basic pytorch dataset class that will be used to load the data. 
This class will be used to load the data and apply the normalization and augmentation functions to the data as it is loaded.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + } + }, + "outputs": [], + "source": [ + "# Define a Dataset\n", + "class CAREDataset(Dataset): # CAREDataset inherits from the PyTorch Dataset class\n", + " def __init__(\n", + " self, image_data: np.ndarray, target_data: np.ndarray, apply_augmentations=False\n", + " ):\n", + " # these are the \"members\" of the CAREDataset\n", + " self.image_data = image_data\n", + " self.target_data = target_data\n", + " self.patch_augment = apply_augmentations\n", + "\n", + " def __len__(self):\n", + " \"\"\"Return the total number of patches.\n", + "\n", + " This method is called when applying `len(...)` to an instance of our class\n", + " \"\"\"\n", + " return self.image_data.shape[\n", + " 0\n", + " ] # Your code here, define the total number of patches\n", + "\n", + " def __getitem__(self, index):\n", + " \"\"\"Return a single pair of patches.\"\"\"\n", + "\n", + " # get patch\n", + " patch = self.image_data[index]\n", + "\n", + " # get target\n", + " target = self.target_data[index]\n", + "\n", + " # Apply transforms\n", + " if self.patch_augment:\n", + " patch, target = augment_batch(patch=patch, target=target)\n", + "\n", + " # Normalize the patch\n", + " patch = normalize(patch, train_mean, train_std)\n", + " target = normalize(target, target_mean, target_std)\n", + "\n", + " return patch[np.newaxis].astype(np.float32), target[np.newaxis].astype(\n", + " np.float32\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# test the dataset\n", + "train_dataset = CAREDataset(\n", + " image_data=train_images_patches, target_data=train_targets_patches\n", + ")\n", + "val_dataset = CAREDataset(\n", + " image_data=val_images_patches, target_data=val_targets_patches\n", + ")\n", + "\n", + "# what is the dataset length?\n", + "assert len(train_dataset) == train_images_patches.shape[0], \"Dataset length is wrong\"\n", + "\n", + "# check the normalization\n", + "assert train_dataset[42][0].max() <= 10, \"Patch isn't normalized properly\"\n", + "assert train_dataset[42][1].max() <= 10, \"Target patch isn't normalized properly\"\n", + "\n", + "# check the get_item function\n", + "assert train_dataset[42][0].shape == (1, *patch_size), \"Patch size is wrong\"\n", + "assert train_dataset[42][1].shape == (1, *patch_size), \"Target patch size is wrong\"\n", + "assert train_dataset[42][0].dtype == np.float32, \"Patch dtype is wrong\"\n", + "assert train_dataset[42][1].dtype == np.float32, \"Target patch dtype is wrong\"\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The training and validation data are stored as an instance of a `Dataset`. \n", + "This describes how each image should be loaded.\n", + "Now we will prepare them to be fed into the model with a `Dataloader`.\n", + "\n", + "This will use the Dataset to load individual images and organise them into batches.\n", + "The Dataloader will shuffle the data at the start of each epoch, outputting different random batches." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Instantiate the dataset and create a DataLoader\n", + "train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n", + "val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
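Once the loaders above exist, a quick sanity check (a suggested snippet, not in the original notebook) is to pull one batch and inspect its shape; with the settings above the training loader should yield batches of shape (32, 1, 128, 128):

```python
# Suggested sanity check: one batch from the training loader
batch, target = next(iter(train_dataloader))
print(batch.shape, target.shape)   # expected: torch.Size([32, 1, 128, 128]) for both
print(batch.dtype, float(batch.mean()), float(batch.std()))
```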

Checkpoint 1: Data

\n", + "\n", + "In this section, we prepared paired training data. \n", + "The steps were:\n", + "1) Loading the images.\n", + "2) Cropping them into patches.\n", + "3) Checking the patches visually.\n", + "4) Creating an instance of a pytorch dataset and dataloader.\n", + "\n", + "You'll see a similar preparation procedure followed for most deep learning vision tasks.\n", + "\n", + "Next, we'll use this data to train a denoising model.\n", + "
\n", + "\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Part 2: Training the model\n", + "\n", + "Image restoration task is very similar to the semantic segmentation task we have done in the previous exercise. We can use the same UNet model and just need to adapt a few things.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![image](nb_data/carenet.png)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Instantiate the model\n", + "\n", + "We'll be using the model from the previous exercise, so we need to load the relevant module" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the model\n", + "model = UNet(depth=2, in_channels=1, out_channels=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

Task 1: Loss function

\n", + "\n", + "CARE trains image to image, therefore we need a different loss function compared to the segmentation task (image to mask). Can you think of a suitable loss function?\n", + "\n", + "*hint: look in the `torch.nn` module of PyTorch ([link](https://pytorch.org/docs/stable/nn.html#loss-functions)).*\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "loss = #### YOUR CODE HERE ####" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

Task 2: Optimizer

\n", + "\n", + "Similarly, define the optimizer. No need to be too inventive here!\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "optimizer = #### YOUR CODE HERE ####" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training\n", + "\n", + "Here we will train a CARE model using classes and functions you defined in the previous tasks.\n", + "We're using the same training loop as in the semantic segmentation exercise.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Task 3: Tensorboard

\n", + "\n", + "We'll monitor the training of all models in 05_image_restoration using Tensorboard. \n", + "This is a program that plots the training and validation loss of networks as they train, and can also show input/output image pairs.\n", + "Follow these steps to launch Tensorboard.\n", + "\n", + "1) Open the extensions panel in VS Code. Look for this icon. \n", + "\n", + "![image](nb_data/extensions.png)\n", + "\n", + "2) Search Tensorboard and install and install the extension published by Microsoft.\n", + "3) Start training. Run the cell below to begin training the model and generating logs.\n", + "3) Once training is started. Open the command palette (ctrl+shift+p), search for Python: Launch Tensorboard and hit enter.\n", + "4) When prompted, select \"Select another folder\" and enter the path to the `01_CARE/runs/` directory.\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In tensorboard, click the SCALARS tab to see the training and validation loss curves. \n", + "At the end of each epoch, refresh Tensorboard using the button in the top right to see the latest loss.\n", + "\n", + "Click the IMAGES tab to see the noisy inputs, denoised outputs and clean targets.\n", + "These are updated at the end of each epoch too." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Training loop\n", + "n_epochs = 5\n", + "device = \"cuda\" if cuda.is_available() else \"cpu\"\n", + "model.to(device)\n", + "\n", + "# tensorboard\n", + "tb_logger = SummaryWriter(\"runs/Unet\"+datetime.now().strftime('%d%H-%M%S'))\n", + "def log_image(image, tag, logger, step):\n", + " normalised_image = image.cpu().numpy()\n", + " normalised_image = normalised_image - np.percentile(normalised_image, 1)\n", + " normalised_image = normalised_image / np.percentile(normalised_image, 99)\n", + " normalised_image = np.clip(normalised_image, 0, 1)\n", + " logger.add_images(tag=tag, img_tensor=normalised_image, global_step=step)\n", + "\n", + "\n", + "train_losses = []\n", + "val_losses = []\n", + "\n", + "for epoch in range(n_epochs):\n", + " model.train()\n", + " for i, (image_batch, target_batch) in enumerate(train_dataloader):\n", + " batch = image_batch.to(device)\n", + " target = target_batch.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " output = model(batch)\n", + " train_loss = loss(output, target)\n", + " train_loss.backward()\n", + " optimizer.step()\n", + "\n", + " if i % 10 == 0:\n", + " print(f\"Epoch: {epoch}, Batch: {i}, Loss: {train_loss.item()}\")\n", + "\n", + " model.eval()\n", + "\n", + " with no_grad():\n", + " val_loss = 0\n", + " for i, (batch, target) in enumerate(val_dataloader):\n", + " batch = batch.to(device)\n", + " target = target.to(device)\n", + "\n", + " output = model(batch)\n", + " val_loss = loss(output, target)\n", + "\n", + " # log tensorboard\n", + " step = epoch * len(train_dataloader)\n", + " tb_logger.add_scalar(tag=\"train_loss\", scalar_value=train_loss, global_step=step)\n", + " tb_logger.add_scalar(tag=\"val_loss\", scalar_value=val_loss, global_step=step)\n", + "\n", + " # we always log the last validation images\n", + " log_image(batch, tag=\"val_input\", logger=tb_logger, step=step)\n", + " log_image(target, tag=\"val_target\", logger=tb_logger, step=step)\n", + " log_image(output, tag=\"val_prediction\", logger=tb_logger, step=step)\n", + "\n", + " print(f\"Validation loss: {val_loss.item()}\")\n", + "\n", + " # Save the losses for plotting\n", + " train_losses.append(train_loss.item())\n", + " val_losses.append(val_loss.item())\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plot the loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot training and validation losses\n", + "plt.figure(figsize=(10, 5))\n", + "plt.plot(train_losses)\n", + "plt.plot(val_losses)\n", + "plt.xlabel(\"Iterations\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.legend([\"Train loss\", \"Validation loss\"])\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Checkpoint 2: Training

\n", + "\n", + "In this section, we created and trained a UNet for denoising.\n", + "We:\n", + "1) Instantiated the model with random weights.\n", + "2) Chose a loss function to compare the output image to the ground truth clean image.\n", + "3) Chose an optimizer to minimize that loss function.\n", + "4) Trained the model with this optimizer.\n", + "5) Examined the training and validation loss curves to see how well our model trained.\n", + "\n", + "Next, we'll load a test set of noisy images and see how well our model denoises them.\n", + "
\n", + "\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Part 3: Predicting on the test dataset\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Define the dataset for the test data\n", + "test_dataset = CAREDataset(\n", + " image_data=test_images_array, target_data=test_targets_array\n", + ")\n", + "test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

Task 4: Predict using the correct mean/std

\n", + "\n", + "In Part 1 we normalized the inputs and the targets before feeding them into the model. This means that the model will output normalized clean images, but we'd like them to be on the same scale as the real clean images.\n", + "\n", + "Recall the variables we used to normalize the data in Part 1, and use them denormalize the output of the model.\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def denormalize(\n", + " image: np.ndarray,\n", + " mean: float = 0.0,\n", + " std: float = 1.0,\n", + ") -> np.ndarray:\n", + " \"\"\"\n", + " Denormalize an image with given mean and standard deviation.\n", + "\n", + " Parameters\n", + " ----------\n", + " image : np.ndarray\n", + " Array containing single image or patch, 2D or 3D.\n", + " mean : float, optional\n", + " Mean value for normalization, by default 0.0.\n", + " std : float, optional\n", + " Standard deviation value for normalization, by default 1.0.\n", + "\n", + " Returns\n", + " -------\n", + " np.ndarray\n", + " Denormalized array.\n", + " \"\"\"\n", + " return image * std + mean" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "# Define the prediction loop\n", + "predictions = []\n", + "\n", + "model.eval()\n", + "with no_grad():\n", + " for i, (image_batch, target_batch) in enumerate(test_dataloader):\n", + " image_batch = image_batch.to(device)\n", + " target_batch = target_batch.to(device)\n", + " output = model(image_batch)\n", + "\n", + " # Save the predictions for visualization\n", + " predictions.append(denormalize(output.cpu().numpy(), #### YOUR CODE HERE ####))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize the predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(3, 2, figsize=(15, 15))\n", + "ax[0, 0].imshow(test_images_array[0].squeeze(), cmap=\"magma\")\n", + "ax[0, 0].set_title(\"Test image\")\n", + "ax[0, 1].imshow(predictions[0][0].squeeze(), cmap=\"magma\")\n", + "ax[0, 1].set_title(\"Prediction\")\n", + "ax[1, 0].imshow(test_images_array[1].squeeze(), cmap=\"magma\")\n", + "ax[1, 0].set_title(\"Test image\")\n", + "ax[1, 1].imshow(predictions[1][0].squeeze(), cmap=\"magma\")\n", + "ax[1, 1].set_title(\"Prediction\")\n", + "ax[2, 0].imshow(test_images_array[2].squeeze(), cmap=\"magma\")\n", + "ax[2, 0].set_title(\"Test image\")\n", + "ax[2, 1].imshow(predictions[2][0].squeeze(), cmap=\"magma\")\n", + "plt.tight_layout()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Checkpoint 3: Predicting

\n", + "\n", + "In this section, we evaluated the performance of our denoiser.\n", + "We:\n", + "1) Created a CAREDataset and Dataloader for a prediction loop.\n", + "2) Ran a prediction loop on the test data.\n", + "3) Examined the outputs.\n", + "\n", + "This notebook has shown how matched pairs of noisy and clean images can train a UNet to denoise, but what if we don't have any clean images? In the next notebook, we'll try Noise2Void, a method for training a UNet to denoise with only noisy images.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "05_image_restoration", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/nb_material/carenet.png b/01_CARE/nb_data/carenet.png old mode 100644 new mode 100755 similarity index 100% rename from nb_material/carenet.png rename to 01_CARE/nb_data/carenet.png diff --git a/nb_material/denoising_binning_overview.png b/01_CARE/nb_data/denoising_binning_overview.png old mode 100644 new mode 100755 similarity index 100% rename from nb_material/denoising_binning_overview.png rename to 01_CARE/nb_data/denoising_binning_overview.png diff --git a/01_CARE/nb_data/extensions.png b/01_CARE/nb_data/extensions.png new file mode 100644 index 0000000..7c14814 Binary files /dev/null and b/01_CARE/nb_data/extensions.png differ diff --git a/nb_material/img_intro.png b/01_CARE/nb_data/img_intro.png old mode 100644 new mode 100755 similarity index 100% rename from nb_material/img_intro.png rename to 01_CARE/nb_data/img_intro.png diff --git a/01_CARE/nb_data/tradeoff.png b/01_CARE/nb_data/tradeoff.png new file mode 100755 index 0000000..3a31123 Binary files /dev/null and b/01_CARE/nb_data/tradeoff.png differ diff --git a/01_CARE/solution.ipynb b/01_CARE/solution.ipynb new file mode 100755 index 0000000..8548aaf --- /dev/null +++ b/01_CARE/solution.ipynb @@ -0,0 +1,1130 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Content-aware image restoration\n", + "\n", + "Fluorescence microscopy is constrained by the microscope's optics, fluorophore chemistry, and the sample's photon tolerance. These constraints require balancing imaging speed, resolution, light exposure, and depth. CARE demonstrates how Deep learning can extend the range of biological phenomena observable by microscopy when any of these factor becomes limiting.\n", + "\n", + "**Reference**: Weigert, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097. doi:[10.1038/s41592-018-0216-7](https://www.nature.com/articles/s41592-018-0216-7)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CARE\n", + "\n", + "In this first exercise we will train a CARE model for a 2D denoising task. CARE stands for Content-Aware image REstoration, and is a supervised method in which we use pairs of degraded and high quality image to train a particular task. The original paper demonstrated improvement of image quality on a variety of tasks such as image restoration or resolution improvement. Here, we will apply CARE to denoise images acquired at low laser power in order to recover the biological structures present in the data!\n", + "\n", + "

\n", + " \"Denoising \n", + "

\n", + "\n", + "We'll use the UNet model that we built in the semantic segmentation exercise and use a different set of functions to train the model for restoration rather than segmentation.\n", + "\n", + "\n", + "

Objectives

\n", + " \n", + "- Train a UNet on a new task!\n", + "- Understand how to train CARE\n", + " \n", + "
\n", + "\n", + "\n", + "\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + " Set your python kernel to 05_image_restoration\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import tifffile\n", + "import numpy as np\n", + "from pathlib import Path\n", + "from typing import Union, List, Tuple\n", + "from torch.utils.data import Dataset, DataLoader\n", + "import torch.nn\n", + "import torch.optim\n", + "from torch import no_grad, cuda\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "from datetime import datetime\n", + "from dlmbl_unet import UNet\n", + "\n", + "%matplotlib inline" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "
\n", + "\n", + "## Part 1: Set-up the data\n", + "\n", + "CARE is a fully supervised algorithm, therefore we need image pairs for training. In practice this is best achieved by acquiring each image twice, once with short exposure time or low laser power to obtain a noisy low-SNR (signal-to-noise ratio) image, and once with high SNR.\n", + "\n", + "Here, we will be using high SNR images of Human U2OS cells taken from the Broad Bioimage Benchmark Collection ([BBBC006v1](https://bbbc.broadinstitute.org/BBBC006)). The low SNR images were created by synthetically adding strong read-out and shot noise, and applying pixel binning of 2x2, thus mimicking acquisitions at a very low light level.\n", + "\n", + "Since the image pairs were synthetically created in this example, they are already aligned perfectly. Note that when working with real paired acquisitions, the low and high SNR images are not pixel-perfect aligned so they would often need to be co-registered before training a CARE model." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Split the dataset into training and validation\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the paths\n", + "root_path = Path(\"./../data\")\n", + "root_path = root_path / \"denoising-CARE_U2OS.unzip\" / \"data\" / \"U2OS\"\n", + "assert root_path.exists(), f\"Path {root_path} does not exist\"\n", + "\n", + "train_images_path = root_path / \"train\" / \"low\"\n", + "train_targets_path = root_path / \"train\" / \"GT\"\n", + "test_image_path = root_path / \"test\" / \"low\"\n", + "test_target_path = root_path / \"test\" / \"GT\"\n", + "\n", + "\n", + "image_files = list(Path(train_images_path).rglob(\"*.tif\"))\n", + "target_files = list(Path(train_targets_path).rglob(\"*.tif\"))\n", + "assert len(image_files) == len(\n", + " target_files\n", + "), \"Number of images and targets do not match\"\n", + "\n", + "print(f\"Total size of train dataset: {len(image_files)}\")\n", + "\n", + "# Split the train data into train and validation\n", + "seed = 42\n", + "train_files_percentage = 0.8\n", + "np.random.seed(seed)\n", + "shuffled_indices = np.random.permutation(len(image_files))\n", + "image_files = np.array(image_files)[shuffled_indices]\n", + "target_files = np.array(target_files)[shuffled_indices]\n", + "assert all(\n", + " [i.name == j.name for i, j in zip(image_files, target_files)]\n", + "), \"Files do not match\"\n", + "\n", + "train_image_files = image_files[: int(train_files_percentage * len(image_files))]\n", + "train_target_files = target_files[: int(train_files_percentage * len(target_files))]\n", + "val_image_files = image_files[int(train_files_percentage * len(image_files)) :]\n", + "val_target_files = target_files[int(train_files_percentage * len(target_files)) :]\n", + "assert all(\n", + " [i.name == j.name for i, j in zip(train_image_files, train_target_files)]\n", + "), \"Train files do not match\"\n", + "assert all(\n", + " [i.name == j.name for i, j in zip(val_image_files, val_target_files)]\n", + "), \"Val files do not match\"\n", + "\n", + "print(f\"Train dataset size: {len(train_image_files)}\")\n", + "print(f\"Validation dataset size: {len(val_image_files)}\")\n", + "\n", + "# Read the test files\n", + "test_image_files = list(test_image_path.rglob(\"*.tif\"))\n", + "test_target_files = list(test_target_path.rglob(\"*.tif\"))\n", + "print(f\"Number of test files: {len(test_image_files)}\")" + ] + }, + { + 
"attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Patching function\n", + "\n", + "In the majority of cases microscopy images are too large to be processed at once and need to be divided into smaller patches. We will define a function that takes image and target arrays and extract random (paired) patches from them.\n", + "\n", + "The method is a bit scary because accessing the whole patch coordinates requires some magical python expressions. \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_patches(\n", + " image_array: np.ndarray,\n", + " target_array: np.ndarray,\n", + " patch_size: Union[List[int], Tuple[int, ...]],\n", + ") -> Tuple[np.ndarray, np.ndarray]:\n", + " \"\"\"\n", + " Create random patches from an array and a target.\n", + "\n", + " The method calculates how many patches the image can be divided into and then\n", + " extracts an equal number of random patches.\n", + "\n", + " Important: the images should have an extra dimension before the spatial dimensions.\n", + " if you try it with only 2D or 3D images, don't forget to add an extra dimension\n", + " using `image = image[np.newaxis, ...]`\n", + " \"\"\"\n", + " # random generator\n", + " rng = np.random.default_rng()\n", + " image_patches = []\n", + " target_patches = []\n", + "\n", + " # iterate over the number of samples in the input array\n", + " for s in range(image_array.shape[0]):\n", + " # calculate the number of patches we can extract\n", + " sample = image_array[s]\n", + " target_sample = target_array[s]\n", + " n_patches = np.ceil(np.prod(sample.shape) / np.prod(patch_size)).astype(int)\n", + "\n", + " # iterate over the number of patches\n", + " for _ in range(n_patches):\n", + " # get random coordinates for the patch and create the crop coordinates\n", + " crop_coords = [\n", + " rng.integers(0, sample.shape[i] - patch_size[i], endpoint=True)\n", + " for i in range(len(patch_size))\n", + " ]\n", + "\n", + " # extract patch from the data\n", + " patch = (\n", + " sample[\n", + " (\n", + " ...,\n", + " *[\n", + " slice(c, c + patch_size[i])\n", + " for i, c in enumerate(crop_coords)\n", + " ],\n", + " )\n", + " ]\n", + " .copy()\n", + " .astype(np.float32)\n", + " )\n", + "\n", + " # same for the target patch\n", + " target_patch = (\n", + " target_sample[\n", + " (\n", + " ...,\n", + " *[\n", + " slice(c, c + patch_size[i])\n", + " for i, c in enumerate(crop_coords)\n", + " ],\n", + " )\n", + " ]\n", + " .copy()\n", + " .astype(np.float32)\n", + " )\n", + "\n", + " # add the patch pair to the list\n", + " image_patches.append(patch)\n", + " target_patches.append(target_patch)\n", + "\n", + " # return stack of patches\n", + " return np.stack(image_patches), np.stack(target_patches)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create patches\n", + "\n", + "To train the network, we will use patches of size 128x128. We first need to load the data, stack it and then call our patching function." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load images and stack them into arrays\n", + "train_images_array = np.stack([tifffile.imread(str(f)) for f in train_image_files])\n", + "train_targets_array = np.stack([tifffile.imread(str(f)) for f in train_target_files])\n", + "val_images_array = np.stack([tifffile.imread(str(f)) for f in val_image_files])\n", + "val_targets_array = np.stack([tifffile.imread(str(f)) for f in val_target_files])\n", + "\n", + "test_images_array = np.stack([tifffile.imread(str(f)) for f in test_image_files])\n", + "test_targets_array = np.stack([tifffile.imread(str(f)) for f in test_target_files])\n", + "\n", + "\n", + "print(f\"Train images array shape: {train_images_array.shape}\")\n", + "print(f\"Validation images array shape: {val_images_array.shape}\")\n", + "print(f\"Test array shape: {test_images_array.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create patches\n", + "patch_size = (128, 128)\n", + "\n", + "train_images_patches, train_targets_patches = create_patches(\n", + " train_images_array, train_targets_array, patch_size\n", + ")\n", + "assert (\n", + " train_images_patches.shape[0] == train_targets_patches.shape[0]\n", + "), \"Number of patches do not match\"\n", + "\n", + "val_images_patches, val_targets_patches = create_patches(\n", + " val_images_array, val_targets_array, patch_size\n", + ")\n", + "assert (\n", + " val_images_patches.shape[0] == val_targets_patches.shape[0]\n", + "), \"Number of patches do not match\"\n", + "\n", + "print(f\"Train images patches shape: {train_images_patches.shape}\")\n", + "print(f\"Validation images patches shape: {val_images_patches.shape}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize training patches" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(3, 2, figsize=(15, 15))\n", + "ax[0, 0].imshow(train_images_patches[0], cmap=\"magma\")\n", + "ax[0, 0].set_title(\"Train image\")\n", + "ax[0, 1].imshow(train_targets_patches[0], cmap=\"magma\")\n", + "ax[0, 1].set_title(\"Train target\")\n", + "ax[1, 0].imshow(train_images_patches[1], cmap=\"magma\")\n", + "ax[1, 0].set_title(\"Train image\")\n", + "ax[1, 1].imshow(train_targets_patches[1], cmap=\"magma\")\n", + "ax[1, 1].set_title(\"Train target\")\n", + "ax[2, 0].imshow(train_images_patches[2], cmap=\"magma\")\n", + "ax[2, 0].set_title(\"Train image\")\n", + "ax[2, 1].imshow(train_targets_patches[2], cmap=\"magma\")\n", + "ax[2, 1].set_title(\"Train target\")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "### Dataset class\n", + "\n", + "In modern deep learning libraries, the data is often wrapped into a class called a `Dataset`. Instances of that class are then used to extract the patches before feeding them to the network.\n", + "\n", + "Here, the class will be wrapped around our pre-computed stacks of patches. Our `CAREDataset` class is built on top of the PyTorch `Dataset` class (we say it \"inherits\" from `Dataset`, the \"parent\" class). 
That means that it has some functions hidden from us that are defined in the PyTorch repository, but that we also need to implement specific pre-defined methods, such as `__len__` and `__getitem__`. The advantage is that PyTorch knows what to do with a `Dataset` \"child\" class, since its behaviour is defined in the `Dataset` class, but we can do operations that are closely related to our own data in the methods we implement." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Question: Normalization

\n", + "\n", + "In the following cell we calculate the mean and standard deviation of the input and target images so that we can normalize them.\n", + "Why is normalization important? \n", + "Should we normalize the input and ground truth data the same way? \n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [ + "solution" + ] + }, + "source": [ + "Normalization brings the data's values into a standardized range, making the magnitude of gradients suitable for the default learning rate. \n", + "The target noise-free images have a much higher intensity than the noisy input images.\n", + "They need to be normalized using their own statistics to bring them into the same range." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + } + }, + "outputs": [], + "source": [ + "# Calculate the mean and std of the train dataset\n", + "train_mean = train_images_array.mean()\n", + "train_std = train_images_array.std()\n", + "target_mean = train_targets_array.mean()\n", + "target_std = train_targets_array.std()\n", + "print(f\"Train mean: {train_mean}, std: {train_std}\")\n", + "print(f\"Target mean: {target_mean}, std: {target_std}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "These functions will be used to normalize the data and perform data augmentation as it is loaded." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def normalize(\n", + " image: np.ndarray,\n", + " mean: float = 0.0,\n", + " std: float = 1.0,\n", + ") -> np.ndarray:\n", + " \"\"\"\n", + " Normalize an image with given mean and standard deviation.\n", + "\n", + " Parameters\n", + " ----------\n", + " image : np.ndarray\n", + " Array containing single image or patch, 2D or 3D.\n", + " mean : float, optional\n", + " Mean value for normalization, by default 0.0.\n", + " std : float, optional\n", + " Standard deviation value for normalization, by default 1.0.\n", + "\n", + " Returns\n", + " -------\n", + " np.ndarray\n", + " Normalized array.\n", + " \"\"\"\n", + " return (image - mean) / std\n", + "\n", + "\n", + "def _flip_and_rotate(\n", + " image: np.ndarray, rotate_state: int, flip_state: int\n", + ") -> np.ndarray:\n", + " \"\"\"\n", + " Apply the given number of 90 degrees rotations and flip to an array.\n", + "\n", + " Parameters\n", + " ----------\n", + " image : np.ndarray\n", + " Array containing single image or patch, 2D or 3D.\n", + " rotate_state : int\n", + " Number of 90 degree rotations to apply.\n", + " flip_state : int\n", + " 0 or 1, whether to flip the array or not.\n", + "\n", + " Returns\n", + " -------\n", + " np.ndarray\n", + " Flipped and rotated array.\n", + " \"\"\"\n", + " rotated = np.rot90(image, k=rotate_state, axes=(-2, -1))\n", + " flipped = np.flip(rotated, axis=-1) if flip_state == 1 else rotated\n", + " return flipped.copy()\n", + "\n", + "\n", + "def augment_batch(\n", + " patch: np.ndarray,\n", + " target: np.ndarray,\n", + " seed: int = 42,\n", + ") -> Tuple[np.ndarray, ...]:\n", + " \"\"\"\n", + " Apply augmentation function to patches and masks.\n", + "\n", + " Parameters\n", + " ----------\n", + " patch : np.ndarray\n", + " Array containing single image or patch, 2D or 3D with masked pixels.\n", + " original_image : np.ndarray\n", + " Array containing original image or patch, 2D or 3D.\n", + " mask : np.ndarray\n", + " Array containing only masked pixels, 2D or 3D.\n", + " seed : int, optional\n", + " Seed for random number generator, controls the rotation and flipping.\n", + "\n", + " Returns\n", + " -------\n", + " Tuple[np.ndarray, ...]\n", + " Tuple of augmented arrays.\n", + " \"\"\"\n", + " rng = np.random.default_rng(seed=seed)\n", + " 
rotate_state = rng.integers(0, 4)\n", + " flip_state = rng.integers(0, 2)\n", + " return (\n", + " _flip_and_rotate(patch, rotate_state, flip_state),\n", + " _flip_and_rotate(target, rotate_state, flip_state),\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Defining the Dataset\n", + "\n", + "Here we're defining the basic pytorch dataset class that will be used to load the data. This class will be used to load the data and apply the normalization and augmentation functions to the data as it is loaded.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + } + }, + "outputs": [], + "source": [ + "# Define a Dataset\n", + "class CAREDataset(Dataset): # CAREDataset inherits from the PyTorch Dataset class\n", + " def __init__(\n", + " self, image_data: np.ndarray, target_data: np.ndarray, apply_augmentations=False\n", + " ):\n", + " # these are the \"members\" of the CAREDataset\n", + " self.image_data = image_data\n", + " self.target_data = target_data\n", + " self.patch_augment = apply_augmentations\n", + "\n", + " def __len__(self):\n", + " \"\"\"Return the total number of patches.\n", + "\n", + " This method is called when applying `len(...)` to an instance of our class\n", + " \"\"\"\n", + " return self.image_data.shape[\n", + " 0\n", + " ] # Your code here, define the total number of patches\n", + "\n", + " def __getitem__(self, index):\n", + " \"\"\"Return a single pair of patches.\"\"\"\n", + "\n", + " # get patch\n", + " patch = self.image_data[index]\n", + "\n", + " # get target\n", + " target = self.target_data[index]\n", + "\n", + " # Apply transforms\n", + " if self.patch_augment:\n", + " patch, target = augment_batch(patch=patch, target=target)\n", + "\n", + " # Normalize the patch\n", + " patch = normalize(patch, train_mean, train_std)\n", + " target = normalize(target, target_mean, target_std)\n", + "\n", + " return patch[np.newaxis].astype(np.float32), target[np.newaxis].astype(\n", + " np.float32\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# test the dataset\n", + "train_dataset = CAREDataset(\n", + " image_data=train_images_patches, target_data=train_targets_patches\n", + ")\n", + "val_dataset = CAREDataset(\n", + " image_data=val_images_patches, target_data=val_targets_patches\n", + ")\n", + "\n", + "# what is the dataset length?\n", + "assert len(train_dataset) == train_images_patches.shape[0], \"Dataset length is wrong\"\n", + "\n", + "# check the normalization\n", + "assert train_dataset[42][0].max() <= 10, \"Patch isn't normalized properly\"\n", + "assert train_dataset[42][1].max() <= 10, \"Target patch isn't normalized properly\"\n", + "\n", + "# check the get_item function\n", + "assert train_dataset[42][0].shape == (1, *patch_size), \"Patch size is wrong\"\n", + "assert train_dataset[42][1].shape == (1, *patch_size), \"Target patch size is wrong\"\n", + "assert train_dataset[42][0].dtype == np.float32, \"Patch dtype is wrong\"\n", + "assert train_dataset[42][1].dtype == np.float32, \"Target patch dtype is wrong\"\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The training and validation data are stored as an instance of a `Dataset`. 
\n", + "This describes how each image should be loaded.\n", + "Now we will prepare them to be fed into the model with a `Dataloader`.\n", + "\n", + "This will use the Dataset to load individual images and organise them into batches.\n", + "The Dataloader will shuffle the data at the start of each epoch, outputting different random batches." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Instantiate the dataset and create a DataLoader\n", + "train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n", + "val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Checkpoint 1: Data

\n", + "\n", + "In this section, we prepared paired training data. \n", + "The steps were:\n", + "1) Loading the images.\n", + "2) Cropping them into patches.\n", + "3) Checking the patches visually.\n", + "4) Creating an instance of a pytorch dataset and dataloader.\n", + "\n", + "You'll see a similar preparation procedure followed for most deep learning vision tasks.\n", + "\n", + "Next, we'll use this data to train a denoising model.\n", + "
\n", + "\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Part 2: Training the model\n", + "\n", + "Image restoration task is very similar to the semantic segmentation task we have done in the previous exercise. We can use the same UNet model and just need to adapt a few things.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![image](nb_data/carenet.png)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Instantiate the model\n", + "\n", + "We'll be using the model from the previous exercise, so we need to load the relevant module" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the model\n", + "model = UNet(depth=2, in_channels=1, out_channels=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

Task 1: Loss function

\n", + "\n", + "CARE trains image to image, therefore we need a different loss function compared to the segmentation task (image to mask). Can you think of a suitable loss function?\n", + "\n", + "*hint: look in the `torch.nn` module of PyTorch ([link](https://pytorch.org/docs/stable/nn.html#loss-functions)).*\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "loss = #### YOUR CODE HERE ####" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "loss = torch.nn.MSELoss()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

Task 2: Optimizer

\n", + "\n", + "Similarly, define the optimizer. No need to be too inventive here!\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "optimizer = #### YOUR CODE HERE ####" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(\n", + " model.parameters(), lr=1e-4\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training\n", + "\n", + "Here we will train a CARE model using classes and functions you defined in the previous tasks.\n", + "We're using the same training loop as in the semantic segmentation exercise.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Task 3: Tensorboard

\n", + "\n", + "We'll monitor the training of all models in 05_image_restoration using Tensorboard. \n", + "This is a program that plots the training and validation loss of networks as they train, and can also show input/output image pairs.\n", + "Follow these steps to launch Tensorboard.\n", + "\n", + "1) Open the extensions panel in VS Code. Look for this icon. \n", + "\n", + "![image](nb_data/extensions.png)\n", + "\n", + "2) Search Tensorboard and install and install the extension published by Microsoft.\n", + "3) Start training. Run the cell below to begin training the model and generating logs.\n", + "3) Once training is started. Open the command palette (ctrl+shift+p), search for Python: Launch Tensorboard and hit enter.\n", + "4) When prompted, select \"Select another folder\" and enter the path to the `01_CARE/runs/` directory.\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In tensorboard, click the SCALARS tab to see the training and validation loss curves. \n", + "At the end of each epoch, refresh Tensorboard using the button in the top right to see the latest loss.\n", + "\n", + "Click the IMAGES tab to see the noisy inputs, denoised outputs and clean targets.\n", + "These are updated at the end of each epoch too." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Training loop\n", + "n_epochs = 5\n", + "device = \"cuda\" if cuda.is_available() else \"cpu\"\n", + "model.to(device)\n", + "\n", + "# tensorboard\n", + "tb_logger = SummaryWriter(\"runs/Unet\"+datetime.now().strftime('%d%H-%M%S'))\n", + "def log_image(image, tag, logger, step):\n", + " normalised_image = image.cpu().numpy()\n", + " normalised_image = normalised_image - np.percentile(normalised_image, 1)\n", + " normalised_image = normalised_image / np.percentile(normalised_image, 99)\n", + " normalised_image = np.clip(normalised_image, 0, 1)\n", + " logger.add_images(tag=tag, img_tensor=normalised_image, global_step=step)\n", + "\n", + "\n", + "train_losses = []\n", + "val_losses = []\n", + "\n", + "for epoch in range(n_epochs):\n", + " model.train()\n", + " for i, (image_batch, target_batch) in enumerate(train_dataloader):\n", + " batch = image_batch.to(device)\n", + " target = target_batch.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " output = model(batch)\n", + " train_loss = loss(output, target)\n", + " train_loss.backward()\n", + " optimizer.step()\n", + "\n", + " if i % 10 == 0:\n", + " print(f\"Epoch: {epoch}, Batch: {i}, Loss: {train_loss.item()}\")\n", + "\n", + " model.eval()\n", + "\n", + " with no_grad():\n", + " val_loss = 0\n", + " for i, (batch, target) in enumerate(val_dataloader):\n", + " batch = batch.to(device)\n", + " target = target.to(device)\n", + "\n", + " output = model(batch)\n", + " val_loss = loss(output, target)\n", + "\n", + " # log tensorboard\n", + " step = epoch * len(train_dataloader)\n", + " tb_logger.add_scalar(tag=\"train_loss\", scalar_value=train_loss, global_step=step)\n", + " tb_logger.add_scalar(tag=\"val_loss\", scalar_value=val_loss, global_step=step)\n", + "\n", + " # we always log the last validation images\n", + " log_image(batch, tag=\"val_input\", logger=tb_logger, step=step)\n", + " log_image(target, tag=\"val_target\", logger=tb_logger, step=step)\n", + " log_image(output, tag=\"val_prediction\", logger=tb_logger, step=step)\n", + "\n", + " print(f\"Validation loss: {val_loss.item()}\")\n", + "\n", + " # Save the losses for plotting\n", + " train_losses.append(train_loss.item())\n", + " val_losses.append(val_loss.item())\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plot the loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot training and validation losses\n", + "plt.figure(figsize=(10, 5))\n", + "plt.plot(train_losses)\n", + "plt.plot(val_losses)\n", + "plt.xlabel(\"Iterations\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.legend([\"Train loss\", \"Validation loss\"])\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Checkpoint 2: Training

\n", + "\n", + "In this section, we created and trained a UNet for denoising.\n", + "We:\n", + "1) Instantiated the model with random weights.\n", + "2) Chose a loss function to compare the output image to the ground truth clean image.\n", + "3) Chose an optimizer to minimize that loss function.\n", + "4) Trained the model with this optimizer.\n", + "5) Examined the training and validation loss curves to see how well our model trained.\n", + "\n", + "Next, we'll load a test set of noisy images and see how well our model denoises them.\n", + "
\n", + "\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Part 3: Predicting on the test dataset\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Define the dataset for the test data\n", + "test_dataset = CAREDataset(\n", + " image_data=test_images_array, target_data=test_targets_array\n", + ")\n", + "test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

Task 4: Predict using the correct mean/std

\n", + "\n", + "In Part 1 we normalized the inputs and the targets before feeding them into the model. This means that the model will output normalized clean images, but we'd like them to be on the same scale as the real clean images.\n", + "\n", + "Recall the variables we used to normalize the data in Part 1, and use them denormalize the output of the model.\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def denormalize(\n", + " image: np.ndarray,\n", + " mean: float = 0.0,\n", + " std: float = 1.0,\n", + ") -> np.ndarray:\n", + " \"\"\"\n", + " Denormalize an image with given mean and standard deviation.\n", + "\n", + " Parameters\n", + " ----------\n", + " image : np.ndarray\n", + " Array containing single image or patch, 2D or 3D.\n", + " mean : float, optional\n", + " Mean value for normalization, by default 0.0.\n", + " std : float, optional\n", + " Standard deviation value for normalization, by default 1.0.\n", + "\n", + " Returns\n", + " -------\n", + " np.ndarray\n", + " Denormalized array.\n", + " \"\"\"\n", + " return image * std + mean" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "# Define the prediction loop\n", + "predictions = []\n", + "\n", + "model.eval()\n", + "with no_grad():\n", + " for i, (image_batch, target_batch) in enumerate(test_dataloader):\n", + " image_batch = image_batch.to(device)\n", + " target_batch = target_batch.to(device)\n", + " output = model(image_batch)\n", + "\n", + " # Save the predictions for visualization\n", + " predictions.append(denormalize(output.cpu().numpy(), #### YOUR CODE HERE ####))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "# Define the prediction loop\n", + "predictions = []\n", + "\n", + "model.eval()\n", + "with no_grad():\n", + " for i, (image_batch, target_batch) in enumerate(test_dataloader):\n", + " image_batch = image_batch.to(device)\n", + " target_batch = target_batch.to(device)\n", + " output = model(image_batch)\n", + "\n", + " # Save the predictions for visualization\n", + " predictions.append(denormalize(output.cpu().numpy(), train_mean, train_std))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize the predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(3, 2, figsize=(15, 15))\n", + "ax[0, 0].imshow(test_images_array[0].squeeze(), cmap=\"magma\")\n", + "ax[0, 0].set_title(\"Test image\")\n", + "ax[0, 1].imshow(predictions[0][0].squeeze(), cmap=\"magma\")\n", + "ax[0, 1].set_title(\"Prediction\")\n", + "ax[1, 0].imshow(test_images_array[1].squeeze(), cmap=\"magma\")\n", + "ax[1, 0].set_title(\"Test image\")\n", + "ax[1, 1].imshow(predictions[1][0].squeeze(), cmap=\"magma\")\n", + "ax[1, 1].set_title(\"Prediction\")\n", + "ax[2, 0].imshow(test_images_array[2].squeeze(), cmap=\"magma\")\n", + "ax[2, 0].set_title(\"Test image\")\n", + "ax[2, 1].imshow(predictions[2][0].squeeze(), cmap=\"magma\")\n", + "plt.tight_layout()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Checkpoint 3: Predicting

\n", + "\n", + "In this section, we evaluated the performance of our denoiser.\n", + "We:\n", + "1) Created a CAREDataset and Dataloader for a prediction loop.\n", + "2) Ran a prediction loop on the test data.\n", + "3) Examined the outputs.\n", + "\n", + "This notebook has shown how matched pairs of noisy and clean images can train a UNet to denoise, but what if we don't have any clean images? In the next notebook, we'll try Noise2Void, a method for training a UNet to denoise with only noisy images.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "05_image_restoration", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/01_CARE/transforms.py b/01_CARE/transforms.py new file mode 100755 index 0000000..3aef43e --- /dev/null +++ b/01_CARE/transforms.py @@ -0,0 +1,110 @@ +import numpy as np +from typing import Tuple + + +def normalize( + image: np.ndarray, + mean: float = 0.0, + std: float = 1.0, +) -> np.ndarray: + """ + Normalize an image with given mean and standard deviation. + + Parameters + ---------- + image : np.ndarray + Array containing single image or patch, 2D or 3D. + mean : float, optional + Mean value for normalization, by default 0.0. + std : float, optional + Standard deviation value for normalization, by default 1.0. + + Returns + ------- + np.ndarray + Normalized array. + """ + return (image - mean) / std + + +def denormalize( + image: np.ndarray, + mean: float = 0.0, + std: float = 1.0, +) -> np.ndarray: + """ + Denormalize an image with given mean and standard deviation. + + Parameters + ---------- + image : np.ndarray + Array containing single image or patch, 2D or 3D. + mean : float, optional + Mean value for normalization, by default 0.0. + std : float, optional + Standard deviation value for normalization, by default 1.0. + + Returns + ------- + np.ndarray + Denormalized array. + """ + return image * std + mean + + +def _flip_and_rotate( + image: np.ndarray, rotate_state: int, flip_state: int +) -> np.ndarray: + """ + Apply the given number of 90 degrees rotations and flip to an array. + + Parameters + ---------- + image : np.ndarray + Array containing single image or patch, 2D or 3D. + rotate_state : int + Number of 90 degree rotations to apply. + flip_state : int + 0 or 1, whether to flip the array or not. + + Returns + ------- + np.ndarray + Flipped and rotated array. + """ + rotated = np.rot90(image, k=rotate_state, axes=(-2, -1)) + flipped = np.flip(rotated, axis=-1) if flip_state == 1 else rotated + return flipped.copy() + + +def augment_batch( + patch: np.ndarray, + target: np.ndarray, + seed: int = 42, +) -> Tuple[np.ndarray, ...]: + """ + Apply augmentation function to patches and masks. + + Parameters + ---------- + patch : np.ndarray + Array containing single image or patch, 2D or 3D with masked pixels. + original_image : np.ndarray + Array containing original image or patch, 2D or 3D. + mask : np.ndarray + Array containing only masked pixels, 2D or 3D. + seed : int, optional + Seed for random number generator, controls the rotation and falipping. + + Returns + ------- + Tuple[np.ndarray, ...] + Tuple of augmented arrays. 
+ """ + rng = np.random.default_rng(seed=seed) + rotate_state = rng.integers(0, 4) + flip_state = rng.integers(0, 2) + return ( + _flip_and_rotate(patch, rotate_state, flip_state), + _flip_and_rotate(target, rotate_state, flip_state), + ) diff --git a/01_CARE/unet.py b/01_CARE/unet.py new file mode 100755 index 0000000..0a79d49 --- /dev/null +++ b/01_CARE/unet.py @@ -0,0 +1,282 @@ +import math +import torch +import torch.nn as nn +import numpy as np + + +class ConvBlock(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + padding: str = "same", + ): + """ A convolution block for a U-Net. Contains two convolutions, each followed by a ReLU. + + Args: + in_channels (int): The number of input channels for this conv block. Depends on + the layer and side of the U-Net and the hyperparameters. + out_channels (int): The number of output channels for this conv block. Depends on + the layer and side of the U-Net and the hyperparameters. + kernel_size (int): The size of the kernel. A kernel size of N signifies an + NxN square kernel. + padding (str, optional): The type of padding to use. Options are "same" or "valid". + Defaults to "same". + """ + super().__init__() + + # determine padding size based on method + if padding in ("VALID", "valid"): + pad = 0 + elif padding in ("SAME", "same"): + pad = kernel_size // 2 + else: + raise RuntimeError("invalid string value for padding") + + # define layers in conv pass + self.conv_pass = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, padding=pad + ), + torch.nn.ReLU(), + torch.nn.Conv2d( + out_channels, out_channels, kernel_size=kernel_size, padding=pad + ), + torch.nn.ReLU(), + ) + + for _name, layer in self.named_modules(): + if isinstance(layer, torch.nn.Conv2d): + torch.nn.init.kaiming_normal_(layer.weight, nonlinearity="relu") + + def forward(self, x): + return self.conv_pass(x) + + +class Downsample(torch.nn.Module): + def __init__(self, downsample_factor: int): + super().__init__() + + self.downsample_factor = downsample_factor + + self.down = torch.nn.MaxPool2d( + downsample_factor + ) + + def check_valid(self, image_size: tuple[int, int]) -> bool: + """Check if the downsample factor evenly divides each image dimension + """ + for dim in image_size: + if dim % self.downsample_factor != 0: + return False + return True + + def forward(self, x): + if not self.check_valid(tuple(x.size()[-2:])): + raise RuntimeError( + "Can not downsample shape %s with factor %s" + % (x.size(), self.downsample_factor) + ) + + return self.down(x) + + +class CropAndConcat(torch.nn.Module): + def crop(self, x, y): + """Center-crop x to match spatial dimensions given by y.""" + + x_target_size = x.size()[:-2] + y.size()[-2:] + + offset = tuple((a - b) // 2 for a, b in zip(x.size(), x_target_size)) + + slices = tuple(slice(o, o + s) for o, s in zip(offset, x_target_size)) + + return x[slices] + + def forward(self, encoder_output, upsample_output): + encoder_cropped = self.crop(encoder_output, upsample_output) + + return torch.cat([encoder_cropped, upsample_output], dim=1) + +class OutputConv(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + activation: str | None = None, # Accepts the name of any torch activation function (e.g., ``ReLU`` for ``torch.nn.ReLU``). 
+ ): + super().__init__() + self.final_conv = torch.nn.Conv2d(in_channels, out_channels, 1, padding=0) # leave this out + if activation is None: + self.activation = None + else: + self.activation = getattr(torch.nn, activation)() + + def forward(self, x): + x = self.final_conv(x) + if self.activation is not None: + x = self.activation(x) + return x + + +class UNet(torch.nn.Module): + def __init__( + self, + depth: int, + in_channels: int, + out_channels: int = 1, + final_activation: str | None = None, + num_fmaps: int = 64, + fmap_inc_factor: int = 2, + downsample_factor: int = 2, + kernel_size: int = 3, + padding: str = "same", + upsample_mode: str = "nearest", + ): + """A U-Net for 2D input that expects tensors shaped like:: + ``(batch, channels, height, width)``. + Args: + depth: + The number of levels in the U-Net. 2 is the smallest that really + makes sense for the U-Net architecture, as a one layer U-Net is + basically just 2 conv blocks. + in_channels: + The number of input channels in your dataset. + out_channels (optional): + How many output channels you want. Depends on your task. Defaults to 1. + final_activation (optional): + What activation to use in your final output block. Depends on your task. + Defaults to None. + num_fmaps (optional): + The number of feature maps in the first layer. Defaults to 64. + fmap_inc_factor (optional): + By how much to multiply the number of feature maps between + layers. Layer ``l`` will have ``num_fmaps*fmap_inc_factor**l`` + feature maps. Defaults to 2. + downsample_factor (optional): + Factor to use for down- and up-sampling the feature maps between layers. + Defaults to 2. + kernel_size (optional): + Kernel size to use in convolutions on both sides of the UNet. + Defaults to 3. + padding (optional): + How to pad convolutions. Either 'same' or 'valid'. Defaults to "same." + upsample_mode (optional): + The upsampling mode to pass to torch.nn.Upsample. Usually "nearest" + or "bilinear." Defaults to "nearest." + """ + + super().__init__() + + self.depth = depth + self.in_channels = in_channels + self.out_channels = out_channels + self.final_activation = final_activation + self.num_fmaps = num_fmaps + self.fmap_inc_factor = fmap_inc_factor + self.downsample_factor = downsample_factor + self.kernel_size = kernel_size + self.padding = padding + self.upsample_mode = upsample_mode + + # left convolutional passes + self.left_convs = torch.nn.ModuleList() + for level in range(self.depth): + fmaps_in, fmaps_out = self.compute_fmaps_encoder(level) + self.left_convs.append( + ConvBlock( + fmaps_in, + fmaps_out, + self.kernel_size, + self.padding + ) + ) + + # right convolutional passes + self.right_convs = torch.nn.ModuleList() + for level in range(self.depth - 1): + fmaps_in, fmaps_out = self.compute_fmaps_decoder(level) + self.right_convs.append( + ConvBlock( + fmaps_in, + fmaps_out, + self.kernel_size, + self.padding, + ) + ) + + self.downsample = Downsample(self.downsample_factor) + self.upsample = torch.nn.Upsample( + scale_factor=self.downsample_factor, + mode=self.upsample_mode, + ) + self.crop_and_concat = CropAndConcat() + self.final_conv = OutputConv( + self.compute_fmaps_decoder(0)[1], self.out_channels, self.final_activation + ) + + def compute_fmaps_encoder(self, level: int) -> tuple[int, int]: + """Compute the number of input and output feature maps for + a conv block at a given level of the UNet encoder (left side). + + Args: + level (int): The level of the U-Net which we are computing + the feature maps for. 
Level 0 is the input level, level 1 is + the first downsampled layer, and level=depth - 1 is the bottom layer. + + Output (tuple[int, int]): The number of input and output feature maps + of the encoder convolutional pass in the given level. + """ + if level == 0: # Leave out function + fmaps_in = self.in_channels + else: + fmaps_in = self.num_fmaps * self.fmap_inc_factor ** (level - 1) + + fmaps_out = self.num_fmaps * self.fmap_inc_factor**level + return fmaps_in, fmaps_out + + def compute_fmaps_decoder(self, level: int) -> tuple[int, int]: + """Compute the number of input and output feature maps for a conv block + at a given level of the UNet decoder (right side). Note: + The bottom layer (depth - 1) is considered an "encoder" conv pass, + so this function is only valid up to depth - 2. + + Args: + level (int): The level of the U-Net which we are computing + the feature maps for. Level 0 is the input level, level 1 is + the first downsampled layer, and level=depth - 1 is the bottom layer. + + Output (tuple[int, int]): The number of input and output feature maps + of the encoder convolutional pass in the given level. + """ + fmaps_out = self.num_fmaps * self.fmap_inc_factor ** (level) # Leave out function + concat_fmaps = self.compute_fmaps_encoder(level)[ + 1 + ] # The channels that come from the skip connection + fmaps_in = concat_fmaps + self.num_fmaps * self.fmap_inc_factor ** (level + 1) + + return fmaps_in, fmaps_out + + def forward(self, x): + # left side + convolution_outputs = [] + layer_input = x + for i in range(self.depth - 1): # leave out center of for loop + conv_out = self.left_convs[i](layer_input) + convolution_outputs.append(conv_out) + downsampled = self.downsample(conv_out) + layer_input = downsampled + + # bottom + conv_out = self.left_convs[-1](layer_input) + layer_input = conv_out + + # right + for i in range(0, self.depth-1)[::-1]: # leave out center of for loop + upsampled = self.upsample(layer_input) + concat = self.crop_and_concat(convolution_outputs[i], upsampled) + conv_output = self.right_convs[i](concat) + layer_input = conv_output + + return self.final_conv(layer_input) \ No newline at end of file diff --git a/02_Noise2Void/exercise.ipynb b/02_Noise2Void/exercise.ipynb new file mode 100644 index 0000000..215b5cb --- /dev/null +++ b/02_Noise2Void/exercise.ipynb @@ -0,0 +1,868 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "# Noise2Void\n", + "\n", + "In the first exercise, we denoised images with CARE using supervised training. As \n", + "discussed during the lecture, ground-truth data is not always available in life \n", + "sciences. But no panic, Noise2Void is here to help!\n", + "\n", + "Indeed Noise2Void is a self-supervised algorithm, meaning that it trains on the data\n", + "itself and does not require clean images. The idea is to predict the value of a masked\n", + "pixels based on the information from the surrounding pixels. Two underlying hypothesis\n", + "allow N2V to work: the structures are continuous and the noise is pixel-independent, \n", + "that is to say the amount of noise in one pixel is independent from the amount of noise\n", + "in the surrounding pixels. Fortunately for us, it is very often the case in microscopy images!\n", + "\n", + "If N2V does not require pairs of noisy and clean images, then how does it train?\n", + "\n", + "First it selects random pixels in each patch, then it masks them. 
The masking is \n", + "not done by setting their value to 0 (which could disturb the network since it is an\n", + "unexpected value) but by replacing the value with that of one of the neighboring pixels.\n", + "\n", + "Then, the network is trained to predict the value of the masked pixels. Since the masked\n", + "value is different from the original value, the network needs to use the information\n", + "contained in all the pixels surrounding the masked pixel. If the noise is pixel-independent,\n", + "then the network cannot predict the amount of noise in the original pixel and it ends\n", + "up predicting a value close to the \"clean\", or denoised, value.\n", + "\n", + "In this notebook, we will use an existing library called [Careamics](https://careamics.github.io)\n", + "that includes N2V and other algorithms:\n", + "\n", + "

\n", + " \n", + "

\n", + "\n", + "\n", + "## References\n", + "\n", + "- Alexander Krull, Tim-Oliver Buchholz, and Florian Jug. \"[Noise2Void - learning denoising from single noisy images.](https://openaccess.thecvf.com/content_CVPR_2019/html/Krull_Noise2Void_-_Learning_Denoising_From_Single_Noisy_Images_CVPR_2019_paper.html)\" Proceedings of the IEEE/CVF conference on Computer Vision and Pattern Recognition, 2019.\n", + "- Joshua Batson, and Loic Royer. \"[Noise2self: Blind denoising by self-supervision.](http://proceedings.mlr.press/v97/batson19a.html)\" International Conference on Machine Learning. PMLR, 2019." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "\n", + "

Objectives

\n", + " \n", + "- Understand how N2V masks pixels for training\n", + "- Learn how to use CAREamics to train N2V\n", + "- Think about pixel noise and noise correlation\n", + " \n", + "
\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "### Mandatory actions\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "
\n", + "Set your python kernel to 05_image_restoration\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import tifffile\n", + "\n", + "from careamics import CAREamist\n", + "from careamics.config import (\n", + " create_n2v_configuration,\n", + ")\n", + "from careamics.transforms import N2VManipulate\n", + "\n", + "%matplotlib inline" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "
\n", + "\n", + "## Part 1 Visualize the masking algorithm\n", + "\n", + "In this first part, let's inspect how this pixel masking is done before training a N2V network!\n", + "\n", + "Before feeding patches to the network, a set of transformations, or augmentations, are \n", + "applied to them. For instance in microscopy, we usually apply random 90 degrees rotations\n", + "or flip the images. In Noise2Void, we apply one more transformation that replace random pixels\n", + "by a value from their surrounding.\n", + "\n", + "In CAREamics, the transformation is called `N2VManipulate`. It has different \n", + "parameters: `roi_size`, `masked_pixel_percentage` and `strategy`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define a patch size for this exercise\n", + "dummy_patch_size = 10\n", + "\n", + "# Define masking parameters\n", + "roi_size = 3\n", + "masked_pixel_percentage = 10\n", + "strategy = 'uniform'" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Task 1: Explore the N2VManipulate parameters

\n", + "\n", + "Can you understand what `roi_size` and `masked_pixel_percentage` do? What can go wrong if they are too small or too high?\n", + "\n", + "\n", + "Run the cell below to observe the effects!\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Create a dummy patch\n", + "patch = np.arange(dummy_patch_size**2).reshape(dummy_patch_size, dummy_patch_size)\n", + "\n", + "# The pixel manipulator expects a channel dimension, so we need to add it to the patch\n", + "patch = patch[np.newaxis]\n", + "\n", + "# Instantiate the pixel manipulator\n", + "manipulator = N2VManipulate(\n", + " roi_size=roi_size,\n", + " masked_pixel_percentage=masked_pixel_percentage,\n", + " strategy=strategy,\n", + ")\n", + "\n", + "# And apply it\n", + "masked_patch, original_patch, mask = manipulator(patch)\n", + "\n", + "# Visualize the masked patch and the mask\n", + "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", + "ax[0].imshow(masked_patch[0])\n", + "ax[0].title.set_text(\"Manipulated patch\")\n", + "ax[1].imshow(mask[0], cmap=\"gray\")\n", + "ax[1].title.set_text(\"Mask\")\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Questions: Noise2Void masking strategy

\n", + "\n", + "\n", + "So what's really happening on a technical level? \n", + "\n", + "In the basic setting N2V algorithm replaces certain pixels with the values from the vicinity\n", + "Other masking stategies also exist, e.g. median, where replacement value is the median off all the pixels inside the region of interest.\n", + "\n", + "Feel free to play around with the ROI size, patch size and masked pixel percentage parameters\n", + "\n", + "
\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Checkpoint 1: N2V masking

\n", + "
" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## Part 2: Prepare the data\n", + "\n", + "Now that we understand how the masking works, let's train a Noise2Void network! We will\n", + "use a scanning electron microscopy image (SEM)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the paths\n", + "root_path = Path(\"./../data\")\n", + "root_path = root_path / \"denoising-N2V_SEM.unzip\"\n", + "assert root_path.exists(), f\"Path {root_path} does not exist\"\n", + "\n", + "train_images_path = root_path / \"train.tif\"\n", + "validation_images_path = root_path / \"validation.tif\"" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Visualize training data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Load images\n", + "train_image = tifffile.imread(train_images_path)\n", + "print(f\"Train image shape: {train_image.shape}\")\n", + "plt.imshow(train_image, cmap=\"gray\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Visualize validation data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "val_image = tifffile.imread(validation_images_path)\n", + "print(f\"Validation image shape: {val_image.shape}\")\n", + "plt.imshow(val_image, cmap=\"gray\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 3: Create a configuration\n", + "\n", + "CAREamics can be configured either from a yaml file, or with an explicitly created config object.\n", + "In this note book we will create the config object using helper functions. CAREamics will \n", + "validate all the parameters and will output explicit error if some parameters or a combination of parameters isn't allowed. It will also provide default values for missing parameters.\n", + "\n", + "The helper function limits the parameters to what is relevant for N2V, here is a break down of these parameters:\n", + "\n", + "- `experiment_name`: name used to identify the experiment\n", + "- `data_type`: data type, in CAREamics it can only be `tiff` or `array` \n", + "- `axes`: axes of the data, here it would be `YX`. Remember: pytorch and numpy order axes in reverse of what you might be used to. 
If the data were 3D, the axes would be `ZYX`.\n", + "- `patch_size`: size of the patches used for training\n", + "- `batch_size`: size of each batch\n", + "- `num_epochs`: number of epochs\n", + "\n", + "\n", + "There are also optional parameters, for more fine grained details:\n", + "\n", + "- `use_augmentations`: whether to use augmentations (flip and rotation)\n", + "- `use_n2v2`: whether to use N2V2, a N2V variant (see optional exercise)\n", + "- `n_channels`: the number of channels \n", + "- `roi_size`: size of the N2V manipulation region (remember that parameter?)\n", + "- `masked_pixel_percentage`: percentage of pixels to mask\n", + "- `logger`: which logger to use\n", + "\n", + "\n", + "Have a look at the [documentation](https://careamics.github.io) to see the full list of parameters and \n", + "their use!\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "# Create a configuration using the helper function\n", + "training_config = create_n2v_configuration(\n", + " experiment_name=\"dl4mia_n2v_sem\",\n", + " data_type=\"tiff\",\n", + " axes=\"YX\",\n", + " patch_size=[64, 64],\n", + " batch_size=128,\n", + " num_epochs=10,\n", + " roi_size=3,\n", + " masked_pixel_percentage=0.05,\n", + " logger=\"tensorboard\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a configuration using the helper function\n", + "training_config = create_n2v_configuration(\n", + " experiment_name=\"dl4mia_n2v_sem\",\n", + " data_type=\"tiff\",\n", + " axes=\"YX\",\n", + " patch_size=[64, 64],\n", + " batch_size=128,\n", + " num_epochs=10,\n", + " roi_size=11,\n", + " masked_pixel_percentage=0.2,\n", + " logger=\"tensorboard\"\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Initialize the Model\n", + "\n", + "Let's instantiate the model with the configuration we just created. CAREamist is the main class of the library, it will handle creation of the data pipeline, the model, training and inference methods." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "careamist = CAREamist(source=training_config)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 4: Train\n", + "\n", + "Here, we need to specify the paths to training and validation data. We can point to a folder containing \n", + "the data or to a single file. If it fits in memory, then CAREamics will load everything and train on it. If it doesn't, then CAREamics will load the data file by file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "careamist.train(train_source=train_images_path, val_source=validation_images_path)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Task 2: Tensorboard

\n", + "\n", + "Remember the configuration? Didn't we set `logger` to `tensorboard`? Then we can visualize the loss curve!\n", + "\n", + "Open Tensorboard in VS Code (check Task 3 in 01_CARE) to monitor training. \n", + "Logs for this model are stored in the `02_N2V/logs/` folder.\n", + "
\n", + "\n", + "

Question: N2V loss curve

\n", + "\n", + "Do you remember what the loss is in Noise2Void? What is the meaning of the loss curve in that case? Can\n", + "it be easily interpreted?\n", + "
\n", + "\n", + "

Checkpoint 2: Training Noise2Void

\n", + "
" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "We trained, but how well did it do?" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 5. Prediction\n", + "\n", + "In order to predict on an image, we also need to specify the path. We also typically need\n", + "to cut the image into patches, predict on each patch and then stitch the patches back together.\n", + "\n", + "To make the process faster, we can choose bigger tiles than the patches used during training. By default CAREamics uses tiled prediction to handle large images. The tile size can be set via the `tile_size` parameter. Tile overlap is computed automatically based on the network architecture." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "preds = careamist.predict(source=train_images_path, tile_size=(256, 256))[0]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show the full image\n", + "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", + "ax[0].imshow(train_image, cmap=\"gray\")\n", + "ax[1].imshow(preds.squeeze(), cmap=\"gray\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

Question: Inspect the image closely

\n", + "\n", + "If you got a good result, try to inspect the image closely. For instance, the default\n", + "window we used for the close-up image:\n", + "\n", + "`y_start` = 200\n", + "\n", + "`y_end` = 450\n", + "\n", + "`x_start` = 600\n", + "\n", + "`x_end` = 850\n", + "\n", + "Do you see anything peculiar in the fine grained details? What could be the reason for that?\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + } + }, + "outputs": [], + "source": [ + "# Show a close up image\n", + "y_start = 200\n", + "y_end = 450\n", + "x_start = 600\n", + "x_end = 850\n", + "\n", + "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", + "ax[0].imshow(train_image[y_start:y_end, x_start:x_end], cmap=\"gray\")\n", + "ax[1].imshow(preds.squeeze()[y_start:y_end, x_start:x_end], cmap=\"gray\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

Question: Check the residuals

\n", + "\n", + "Compute the absolute difference between original and denoised image. What do you see? \n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + } + }, + "outputs": [], + "source": [ + "plt.imshow(preds.squeeze() - train_image, cmap=\"gray\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Task 4 (Optional): Improving the results

\n", + "\n", + "CAREamics configuration won't allow you to use parameters which are clearly wrong. However, there are many parameters that can be tuned to improve the results. Try to play around with the `roi_size` and `masked_pixel_percentage` and see if you can improve the results.\n", + "\n", + "Do the fine-grained structures observed in during the closer look at the image disappear?\n", + "\n", + "
" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### How to predict without training?\n", + "\n", + "Here again, CAREamics provides a way to create a CAREamist from a checkpoint only,\n", + "allowing predicting without having to retrain." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Instantiate a CAREamist from a checkpoint\n", + "other_careamist = CAREamist(source=\"checkpoints/last.ckpt\")\n", + "\n", + "# And predict\n", + "new_preds = other_careamist.predict(source=train_images_path, tile_size=(256, 256))[0]\n", + "\n", + "# Show the full image\n", + "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", + "ax[0].imshow(train_image, cmap=\"gray\")\n", + "ax[1].imshow(new_preds.squeeze(), cmap=\"gray\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_image[:128, :128].shape" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Checkpoint 3: Prediction

\n", + "
\n", + "\n", + "
\n", + "\n", + "## Part 6: Exporting the model\n", + "\n", + "Have you heard of the [BioImage Model Zoo](https://bioimage.io/#/)? It provides a format for FAIR AI models and allows\n", + "researchers to exchange and reproduce models. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Export model as BMZ\n", + "careamist.export_to_bmz(\n", + " path_to_archive=\"n2v_model.zip\",\n", + " input_array=train_image[:128, :128],\n", + " friendly_model_name=\"SEM_N2V\",\n", + " authors= [{\"name\": \"Jane\", \"affiliation\": \"Doe University\"}],\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

Task 5: Train N2V(2) on a different dataset

\n", + "\n", + "As you remember from the lecture, N2V can only deal with the noise that is pixelwise independent. \n", + "\n", + "Use these cells to train on a different dataset: Mito Confocal, which has noise that is not pixelwise independent, but is spatially correlated. This will be loaded in the following cell.\n", + "\n", + "In the next cells we'll show you how the result of training a N2V model on this dataset looks.\n", + "\n", + "In the next exercise of the course we'll learn how to deal with this kind of noise! \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "mito_path = \"./../data/mito-confocal-lowsnr.tif\"\n", + "mito_image = tifffile.imread(mito_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Configure the model\n", + "mito_training_config = create_n2v_configuration(\n", + " experiment_name=\"dl4mia_n2v2_mito\",\n", + " data_type=\"array\",\n", + " axes=\"SYX\", # <-- we are adding S because we have a stack of images\n", + " patch_size=[64, 64],\n", + " batch_size=64,\n", + " num_epochs=10,\n", + " logger=\"tensorboard\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "careamist = CAREamist(source=mito_training_config)\n", + "careamist.train(\n", + " train_source=mito_image,\n", + " val_percentage=0.1\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preds = careamist.predict(\n", + " source=mito_image[:1], # <-- we predict on a small subset\n", + " data_type=\"array\",\n", + " tile_size=(64, 64),\n", + ")[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the following cell, look closely at the denoising result of applying N2V to data with spatially correlated noise. Zoom in and see if you can find the horizontal artifacts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vmin = np.percentile(mito_image, 1)\n", + "vmax = np.percentile(mito_image, 99)\n", + "\n", + "y_start = 0\n", + "y_end = 1024\n", + "x_start = 0\n", + "x_end = 1024\n", + "\n", + "# Feel free to play around with the visualization\n", + "_, ax = plt.subplots(1, 2, figsize=(10, 5))\n", + "ax[0].imshow(preds[0, 0, 600:700, 300:400], vmin=vmin, vmax=vmax)\n", + "ax[0].title.set_text(\"Predicted\")\n", + "ax[1].imshow(mito_image[0, 600:700, 300:400], vmin=vmin, vmax=vmax)\n", + "ax[1].title.set_text(\"Original\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "

Checkpoint 4: Dealing with artifacts

\n", + "
" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

Take away questions

\n", + "\n", + "- Which is the best saved checkpoint for Noise2Void, the one at the end of the training or the one with lowest validation loss?\n", + "\n", + "- Is validation useful in Noise2Void?\n", + "\n", + "- We predicted on the same image we trained on, is that a good idea?\n", + "\n", + "- Can you reuse the model on another image?\n", + "\n", + "- Can you train on images with multiple channels? RGB images? Biological channels (GFP, RFP, DAPI)?\n", + "\n", + "- N2V training is unsupervised, how can you be sure that the training worked and is not hallucinating?\n", + "
\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

End of the exercise

\n", + "
" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "05_image_restoration", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/02_Noise2Void/solution.ipynb b/02_Noise2Void/solution.ipynb new file mode 100755 index 0000000..f687e54 --- /dev/null +++ b/02_Noise2Void/solution.ipynb @@ -0,0 +1,892 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "# Noise2Void\n", + "\n", + "In the first exercise, we denoised images with CARE using supervised training. As \n", + "discussed during the lecture, ground-truth data is not always available in life \n", + "sciences. But no panic, Noise2Void is here to help!\n", + "\n", + "Indeed Noise2Void is a self-supervised algorithm, meaning that it trains on the data\n", + "itself and does not require clean images. The idea is to predict the value of a masked\n", + "pixels based on the information from the surrounding pixels. Two underlying hypothesis\n", + "allow N2V to work: the structures are continuous and the noise is pixel-independent, \n", + "that is to say the amount of noise in one pixel is independent from the amount of noise\n", + "in the surrounding pixels. Fortunately for us, it is very often the case in microscopy images!\n", + "\n", + "If N2V does not require pairs of noisy and clean images, then how does it train?\n", + "\n", + "First it selects random pixels in each patch, then it masks them. The masking is \n", + "not done by setting their value to 0 (which could disturb the network since it is an\n", + "unexpected value) but by replacing the value with that of one of the neighboring pixels.\n", + "\n", + "Then, the network is trained to predict the value of the masked pixels. Since the masked\n", + "value is different from the original value, the network needs to use the information\n", + "contained in all the pixels surrounding the masked pixel. If the noise is pixel-independent,\n", + "then the network cannot predict the amount of noise in the original pixel and it ends\n", + "up predicting a value close to the \"clean\", or denoised, value.\n", + "\n", + "In this notebook, we will use an existing library called [Careamics](https://careamics.github.io)\n", + "that includes N2V and other algorithms:\n", + "\n", + "

\n", + " \n", + "

\n", + "\n", + "\n", + "## References\n", + "\n", + "- Alexander Krull, Tim-Oliver Buchholz, and Florian Jug. \"[Noise2Void - learning denoising from single noisy images.](https://openaccess.thecvf.com/content_CVPR_2019/html/Krull_Noise2Void_-_Learning_Denoising_From_Single_Noisy_Images_CVPR_2019_paper.html)\" Proceedings of the IEEE/CVF conference on Computer Vision and Pattern Recognition, 2019.\n", + "- Joshua Batson, and Loic Royer. \"[Noise2self: Blind denoising by self-supervision.](http://proceedings.mlr.press/v97/batson19a.html)\" International Conference on Machine Learning. PMLR, 2019." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "\n", + "

Objectives

\n", + " \n", + "- Understand how N2V masks pixels for training\n", + "- Learn how to use CAREamics to train N2V\n", + "- Think about pixel noise and noise correlation\n", + " \n", + "
\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "### Mandatory actions\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "
\n", + "Set your python kernel to 05_image_restoration\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import tifffile\n", + "\n", + "from careamics import CAREamist\n", + "from careamics.config import (\n", + " create_n2v_configuration,\n", + ")\n", + "from careamics.transforms import N2VManipulate\n", + "\n", + "%matplotlib inline" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "
\n", + "\n", + "## Part 1 Visualize the masking algorithm\n", + "\n", + "In this first part, let's inspect how this pixel masking is done before training a N2V network!\n", + "\n", + "Before feeding patches to the network, a set of transformations, or augmentations, are \n", + "applied to them. For instance in microscopy, we usually apply random 90 degrees rotations\n", + "or flip the images. In Noise2Void, we apply one more transformation that replace random pixels\n", + "by a value from their surrounding.\n", + "\n", + "In CAREamics, the transformation is called `N2VManipulate`. It has different \n", + "parameters: `roi_size`, `masked_pixel_percentage` and `strategy`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define a patch size for this exercise\n", + "dummy_patch_size = 10\n", + "\n", + "# Define masking parameters\n", + "roi_size = 3\n", + "masked_pixel_percentage = 10\n", + "strategy = 'uniform'" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Task 1: Explore the N2VManipulate parameters

\n", + "\n", + "Can you understand what `roi_size` and `masked_pixel_percentage` do? What can go wrong if they are too small or too high?\n", + "\n", + "\n", + "Run the cell below to observe the effects!\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Create a dummy patch\n", + "patch = np.arange(dummy_patch_size**2).reshape(dummy_patch_size, dummy_patch_size)\n", + "\n", + "# The pixel manipulator expects a channel dimension, so we need to add it to the patch\n", + "patch = patch[np.newaxis]\n", + "\n", + "# Instantiate the pixel manipulator\n", + "manipulator = N2VManipulate(\n", + " roi_size=roi_size,\n", + " masked_pixel_percentage=masked_pixel_percentage,\n", + " strategy=strategy,\n", + ")\n", + "\n", + "# And apply it\n", + "masked_patch, original_patch, mask = manipulator(patch)\n", + "\n", + "# Visualize the masked patch and the mask\n", + "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", + "ax[0].imshow(masked_patch[0])\n", + "ax[0].title.set_text(\"Manipulated patch\")\n", + "ax[1].imshow(mask[0], cmap=\"gray\")\n", + "ax[1].title.set_text(\"Mask\")\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Questions: Noise2Void masking strategy

\n", + "\n", + "\n", + "So what's really happening on a technical level?\n", + "\n", + "In the basic setting, the N2V algorithm replaces certain pixels with values from their vicinity.\n", + "Other masking strategies also exist, e.g. median, where the replacement value is the median of all the pixels inside the region of interest.\n", + "\n", + "Feel free to play around with the ROI size, patch size and masked pixel percentage parameters.\n", + "\n", + "
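To make the replacement step concrete, here is a minimal NumPy sketch of the idea. It is a toy illustration only, not how CAREamics' `N2VManipulate` is implemented, and the helper name `mask_pixels` is made up for this example:

```python
import numpy as np

def mask_pixels(patch, roi_size=3, masked_pixel_percentage=10, strategy="uniform", rng=None):
    """Toy N2V-style masking: replace a fraction of pixels with a value from their local ROI."""
    rng = np.random.default_rng() if rng is None else rng
    masked = patch.astype(float).copy()
    mask = np.zeros(patch.shape, dtype=bool)
    n_masked = max(1, int(patch.size * masked_pixel_percentage / 100))
    half = roi_size // 2
    ys = rng.integers(half, patch.shape[0] - half, n_masked)
    xs = rng.integers(half, patch.shape[1] - half, n_masked)
    for y, x in zip(ys, xs):
        roi = patch[y - half : y + half + 1, x - half : x + half + 1]
        if strategy == "uniform":
            # pick a random value from the ROI (a real implementation would usually exclude the centre pixel)
            masked[y, x] = rng.choice(roi.ravel())
        else:
            # "median" variant: replace with the median of the ROI
            masked[y, x] = np.median(roi)
        mask[y, x] = True
    return masked, mask

toy_patch = np.arange(100, dtype=float).reshape(10, 10)
masked_toy, toy_mask = mask_pixels(toy_patch)
```

Running this on a toy patch gives the same qualitative picture as the `N2VManipulate` cell: a handful of pixels whose values now come from their neighbourhood, plus a binary mask marking them.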
\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Checkpoint 1: N2V masking

\n", + "
" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## Part 2: Prepare the data\n", + "\n", + "Now that we understand how the masking works, let's train a Noise2Void network! We will\n", + "use a scanning electron microscopy image (SEM)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the paths\n", + "root_path = Path(\"./../data\")\n", + "root_path = root_path / \"denoising-N2V_SEM.unzip\"\n", + "assert root_path.exists(), f\"Path {root_path} does not exist\"\n", + "\n", + "train_images_path = root_path / \"train.tif\"\n", + "validation_images_path = root_path / \"validation.tif\"" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Visualize training data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Load images\n", + "train_image = tifffile.imread(train_images_path)\n", + "print(f\"Train image shape: {train_image.shape}\")\n", + "plt.imshow(train_image, cmap=\"gray\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Visualize validation data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "val_image = tifffile.imread(validation_images_path)\n", + "print(f\"Validation image shape: {val_image.shape}\")\n", + "plt.imshow(val_image, cmap=\"gray\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 3: Create a configuration\n", + "\n", + "CAREamics can be configured either from a yaml file, or with an explicitly created config object.\n", + "In this note book we will create the config object using helper functions. CAREamics will \n", + "validate all the parameters and will output explicit error if some parameters or a combination of parameters isn't allowed. It will also provide default values for missing parameters.\n", + "\n", + "The helper function limits the parameters to what is relevant for N2V, here is a break down of these parameters:\n", + "\n", + "- `experiment_name`: name used to identify the experiment\n", + "- `data_type`: data type, in CAREamics it can only be `tiff` or `array` \n", + "- `axes`: axes of the data, here it would be `YX`. Remember: pytorch and numpy order axes in reverse of what you might be used to. 
If the data were 3D, the axes would be `ZYX`.\n", + "- `patch_size`: size of the patches used for training\n", + "- `batch_size`: size of each batch\n", + "- `num_epochs`: number of epochs\n", + "\n", + "\n", + "There are also optional parameters, for more fine grained details:\n", + "\n", + "- `use_augmentations`: whether to use augmentations (flip and rotation)\n", + "- `use_n2v2`: whether to use N2V2, a N2V variant (see optional exercise)\n", + "- `n_channels`: the number of channels \n", + "- `roi_size`: size of the N2V manipulation region (remember that parameter?)\n", + "- `masked_pixel_percentage`: percentage of pixels to mask\n", + "- `logger`: which logger to use\n", + "\n", + "\n", + "Have a look at the [documentation](https://careamics.github.io) to see the full list of parameters and \n", + "their use!\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "# Create a configuration using the helper function\n", + "training_config = create_n2v_configuration(\n", + " experiment_name=\"dl4mia_n2v_sem\",\n", + " data_type=\"tiff\",\n", + " axes=\"YX\",\n", + " patch_size=[64, 64],\n", + " batch_size=128,\n", + " num_epochs=10,\n", + " roi_size=3,\n", + " masked_pixel_percentage=0.05,\n", + " logger=\"tensorboard\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "# Create a configuration using the helper function\n", + "training_config = create_n2v_configuration(\n", + " experiment_name=\"dl4mia_n2v_sem\",\n", + " data_type=\"tiff\",\n", + " axes=\"YX\",\n", + " patch_size=[64, 64],\n", + " batch_size=128,\n", + " num_epochs=10,\n", + " roi_size=3,\n", + " masked_pixel_percentage=0.05,\n", + " logger=\"tensorboard\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a configuration using the helper function\n", + "training_config = create_n2v_configuration(\n", + " experiment_name=\"dl4mia_n2v_sem\",\n", + " data_type=\"tiff\",\n", + " axes=\"YX\",\n", + " patch_size=[64, 64],\n", + " batch_size=128,\n", + " num_epochs=10,\n", + " roi_size=11,\n", + " masked_pixel_percentage=0.2,\n", + " logger=\"tensorboard\"\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Initialize the Model\n", + "\n", + "Let's instantiate the model with the configuration we just created. CAREamist is the main class of the library, it will handle creation of the data pipeline, the model, training and inference methods." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "careamist = CAREamist(source=training_config)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 4: Train\n", + "\n", + "Here, we need to specify the paths to training and validation data. We can point to a folder containing \n", + "the data or to a single file. If it fits in memory, then CAREamics will load everything and train on it. If it doesn't, then CAREamics will load the data file by file." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "careamist.train(train_source=train_images_path, val_source=validation_images_path)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Task 2: Tensorboard

\n", + "\n", + "Remember the configuration? Didn't we set `logger` to `tensorboard`? Then we can visualize the loss curve!\n", + "\n", + "Open Tensorboard in VS Code (check Task 3 in 01_CARE) to monitor training.\n", + "Logs for this model are stored in the `02_Noise2Void/logs/` folder.\n", + "
\n", + "\n", + "
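If you are not working in VS Code, one alternative is to launch TensorBoard directly from the notebook. This is a sketch, assuming the `tensorboard` package is installed and the logs are written to a `logs/` folder next to this notebook; adjust the path if your setup differs:

```python
# Jupyter magics provided by the tensorboard package; adjust --logdir if your logs live elsewhere.
%load_ext tensorboard
%tensorboard --logdir logs
```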

Question: N2V loss curve

\n", + "\n", + "Do you remember what the loss is in Noise2Void? What is the meaning of the loss curve in that case? Can\n", + "it be easily interpreted?\n", + "
\n", + "\n", + "

Checkpoint 2: Training Noise2Void

\n", + "
" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "We trained, but how well did it do?" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 5. Prediction\n", + "\n", + "In order to predict on an image, we also need to specify the path. We also typically need\n", + "to cut the image into patches, predict on each patch and then stitch the patches back together.\n", + "\n", + "To make the process faster, we can choose bigger tiles than the patches used during training. By default CAREamics uses tiled prediction to handle large images. The tile size can be set via the `tile_size` parameter. Tile overlap is computed automatically based on the network architecture." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "preds = careamist.predict(source=train_images_path, tile_size=(256, 256))[0]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show the full image\n", + "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", + "ax[0].imshow(train_image, cmap=\"gray\")\n", + "ax[1].imshow(preds.squeeze(), cmap=\"gray\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

Question: Inspect the image closely

\n", + "\n", + "If you got a good result, try to inspect the image closely. For instance, the default\n", + "window we used for the close-up image:\n", + "\n", + "`y_start` = 200\n", + "\n", + "`y_end` = 450\n", + "\n", + "`x_start` = 600\n", + "\n", + "`x_end` = 850\n", + "\n", + "Do you see anything peculiar in the fine grained details? What could be the reason for that?\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + } + }, + "outputs": [], + "source": [ + "# Show a close up image\n", + "y_start = 200\n", + "y_end = 450\n", + "x_start = 600\n", + "x_end = 850\n", + "\n", + "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", + "ax[0].imshow(train_image[y_start:y_end, x_start:x_end], cmap=\"gray\")\n", + "ax[1].imshow(preds.squeeze()[y_start:y_end, x_start:x_end], cmap=\"gray\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

Question: Check the residuals

\n", + "\n", + "Compute the difference (or its absolute value) between the original and denoised image. What do you see?\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + } + }, + "outputs": [], + "source": [ + "plt.imshow(preds.squeeze() - train_image, cmap=\"gray\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Task 4 (Optional): Improving the results

\n", + "\n", + "The CAREamics configuration won't let you use parameters that are clearly wrong. However, there are many parameters that can be tuned to improve the results. Try playing around with the `roi_size` and `masked_pixel_percentage` and see if you can improve the results.\n", + "\n", + "Do the fine-grained structures observed during the closer look at the image disappear?\n", + "\n", + "
" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### How to predict without training?\n", + "\n", + "Here again, CAREamics provides a way to create a CAREamist from a checkpoint only,\n", + "allowing predicting without having to retrain." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Instantiate a CAREamist from a checkpoint\n", + "other_careamist = CAREamist(source=\"checkpoints/last.ckpt\")\n", + "\n", + "# And predict\n", + "new_preds = other_careamist.predict(source=train_images_path, tile_size=(256, 256))[0]\n", + "\n", + "# Show the full image\n", + "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", + "ax[0].imshow(train_image, cmap=\"gray\")\n", + "ax[1].imshow(new_preds.squeeze(), cmap=\"gray\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_image[:128, :128].shape" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Checkpoint 3: Prediction

\n", + "
\n", + "\n", + "
\n", + "\n", + "## Part 6: Exporting the model\n", + "\n", + "Have you heard of the [BioImage Model Zoo](https://bioimage.io/#/)? It provides a format for FAIR AI models and allows\n", + "researchers to exchange and reproduce models. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Export model as BMZ\n", + "careamist.export_to_bmz(\n", + " path_to_archive=\"n2v_model.zip\",\n", + " input_array=train_image[:128, :128],\n", + " friendly_model_name=\"SEM_N2V\",\n", + " authors= [{\"name\": \"Jane\", \"affiliation\": \"Doe University\"}],\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

Task 5: Train N2V(2) on a different dataset

\n", + "\n", + "As you remember from the lecture, N2V can only deal with the noise that is pixelwise independent. \n", + "\n", + "Use these cells to train on a different dataset: Mito Confocal, which has noise that is not pixelwise independent, but is spatially correlated. This will be loaded in the following cell.\n", + "\n", + "In the next cells we'll show you how the result of training a N2V model on this dataset looks.\n", + "\n", + "In the next exercise of the course we'll learn how to deal with this kind of noise! \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "mito_path = \"./../data/mito-confocal-lowsnr.tif\"\n", + "mito_image = tifffile.imread(mito_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Configure the model\n", + "mito_training_config = create_n2v_configuration(\n", + " experiment_name=\"dl4mia_n2v2_mito\",\n", + " data_type=\"array\",\n", + " axes=\"SYX\", # <-- we are adding S because we have a stack of images\n", + " patch_size=[64, 64],\n", + " batch_size=64,\n", + " num_epochs=10,\n", + " logger=\"tensorboard\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "careamist = CAREamist(source=mito_training_config)\n", + "careamist.train(\n", + " train_source=mito_image,\n", + " val_percentage=0.1\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preds = careamist.predict(\n", + " source=mito_image[:1], # <-- we predict on a small subset\n", + " data_type=\"array\",\n", + " tile_size=(64, 64),\n", + ")[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the following cell, look closely at the denoising result of applying N2V to data with spatially correlated noise. Zoom in and see if you can find the horizontal artifacts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vmin = np.percentile(mito_image, 1)\n", + "vmax = np.percentile(mito_image, 99)\n", + "\n", + "y_start = 0\n", + "y_end = 1024\n", + "x_start = 0\n", + "x_end = 1024\n", + "\n", + "# Feel free to play around with the visualization\n", + "_, ax = plt.subplots(1, 2, figsize=(10, 5))\n", + "ax[0].imshow(preds[0, 0, 600:700, 300:400], vmin=vmin, vmax=vmax)\n", + "ax[0].title.set_text(\"Predicted\")\n", + "ax[1].imshow(mito_image[0, 600:700, 300:400], vmin=vmin, vmax=vmax)\n", + "ax[1].title.set_text(\"Original\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "

Checkpoint 4: Dealing with artifacts

\n", + "
" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "

Take away questions

\n", + "\n", + "- Which is the best saved checkpoint for Noise2Void: the one at the end of training or the one with the lowest validation loss?\n", + "\n", + "- Is validation useful in Noise2Void?\n", + "\n", + "- We predicted on the same image we trained on. Is that a good idea?\n", + "\n", + "- Can you reuse the model on another image?\n", + "\n", + "- Can you train on images with multiple channels? RGB images? Biological channels (GFP, RFP, DAPI)?\n", + "\n", + "- N2V training is unsupervised, so how can you be sure that the training worked and the model is not hallucinating?\n", + "
\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

End of the exercise

\n", + "
" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "05_image_restoration", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/03_COSDD/bonus-exercise.ipynb b/03_COSDD/bonus-exercise.ipynb new file mode 100644 index 0000000..159f01a --- /dev/null +++ b/03_COSDD/bonus-exercise.ipynb @@ -0,0 +1,312 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Bonus exercise. Generating new images with COSDD\n", + "\n", + "As mentioned in the training.ipynb notebook, COSDD is a deep generative model that captures the structures and characteristics of our data. In this notebook, we'll see how accurately it can represent our training data, in both the signal and the noise. We'll do this by using the model to generate entirely new images. These will be images that look like the ones in our training data but don't actually exist. This is the same as how models like DALL-E can generate entirely new images." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Set your python kernel to 05_image_restoration\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import tifffile\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "from COSDD import utils\n", + "from COSDD.models.hub import Hub\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assert torch.cuda.is_available()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.1. Load trained model and clean and noisy data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 3.1.\n", + "\n", + "Load the model trained in the first notebook by entering your `model_name`, or alternatively, uncomment line 4 to load the pretrained model.\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "model_name = ... # Insert a string here\n", + "checkpoint_path = os.path.join(\"checkpoints\", model_name)\n", + "\n", + "# checkpoint_path = \"checkpoints/mito-confocal-pretrained\"\n", + "\n", + "hub = Hub.load_from_checkpoint(os.path.join(checkpoint_path, \"final_model.ckpt\")).cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# load the data\n", + "lowsnr_path = \"./../data/mito-confocal-lowsnr.tif\"\n", + "low_snr = tifffile.imread(lowsnr_path)\n", + "low_snr = low_snr[:, np.newaxis]\n", + "low_snr = torch.from_numpy(low_snr)\n", + "low_snr = low_snr.to(torch.float32)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.2. Generating new noise for a real noisy image\n", + "\n", + "First, we'll pass a noisy image to the VAE and generate a random sample from the AR decoder. This will give us another noisy image with the same underlying clean signal but a different random sample of noise." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`inp_image` (torch.Tensor): The real noisy image we're going to add a different random sample of noise to.
\n", + "`denoised` (torch.Tensor): The denoised version of `inp_image`.
\n", + "`noisy` (torch.Tensor): The same underlying signal as `inp_image` but a different sample of noise." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "inp_image = low_snr[:1, :, :512, :512].cuda()\n", + "reconstructions = hub.reconstruct(inp_image)\n", + "denoised = reconstructions[\"s_hat\"].cpu()\n", + "noisy = reconstructions[\"x_hat\"].cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vmin = np.percentile(inp_image.cpu().numpy(), 0.1)\n", + "vmax = np.percentile(inp_image.cpu().numpy(), 99.9)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 3.2.\n", + "\n", + "Now we will look at the original noisy image and the generated noisy image. Adjust `top`, `bottom`, `left` and `right` to view different crops of the reconstructed image.\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "top = 0\n", + "bottom = 512\n", + "left = 0\n", + "right = 512\n", + "\n", + "crop = (0, slice(top, bottom), slice(left, right))\n", + "\n", + "fig, ax = plt.subplots(1, 3, figsize=(12, 4))\n", + "ax[0].imshow(inp_image[0][crop].cpu(), vmin=vmin, vmax=vmax)\n", + "ax[0].set_title(\"Original noisy image\")\n", + "ax[1].imshow(noisy[0][crop], vmin=vmin, vmax=vmax)\n", + "ax[1].set_title(\"Generated noisy image\")\n", + "ax[2].imshow(denoised[0][crop], vmin=vmin, vmax=vmax)\n", + "ax[2].set_title(\"Denoised image\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The spatial correlation of the generated noise can be compared to that of the real noise to get an idea of how accurate the model is. Since we have the denoised version of the generated image, we can get a noise sample by just subtracting it from the noisy versions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "real_noise = low_snr[8, 0, 800:, 800:]\n", + "generated_noise = noisy[0, 0] - denoised[0, 0]\n", + "\n", + "real_ac = utils.autocorrelation(real_noise, max_lag=25)\n", + "generated_ac = utils.autocorrelation(generated_noise, max_lag=25)\n", + "\n", + "fig, ax = plt.subplots(1, 2, figsize=(12, 5))\n", + "ac1 = ax[0].imshow(real_ac, cmap=\"seismic\", vmin=-1, vmax=1)\n", + "ax[0].set_title(\"Autocorrelation of real noise\")\n", + "ax[0].set_xlabel(\"Horizontal lag\")\n", + "ax[0].set_ylabel(\"Vertical lag\")\n", + "ac2 = ax[1].imshow(generated_ac, cmap=\"seismic\", vmin=-1, vmax=1)\n", + "ax[1].set_title(\"Autocorrelation of generated noise\")\n", + "ax[1].set_xlabel(\"Horizontal lag\")\n", + "ax[1].set_ylabel(\"Vertical lag\")\n", + "\n", + "fig.colorbar(ac2, fraction=0.045)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.3. Generating new images\n", + "\n", + "This time, we'll take a sample from the VAE's prior. This will be a latent variable containing information about a brand new signal. The signal decoder will take that latent variable and convert it into a clean image. The AR decoder will take the latent variable and create an image with the same clean image plus noise." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 3.3.\n", + "\n", + "Set the `n_imgs` variable below to decide how many images to generate. If you set it too high you'll get an out-of-memory error, but don't worry, just restart the kernel and run again with a lower value.\n", + "\n", + "Explore the images you generated in the second cell below. Look at the differences between them to see what aspects of the signal the model has learned to generate.\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n_imgs = 5 # Insert an integer here\n", + "generations = hub.sample_prior(n_imgs=n_imgs)\n", + "new_denoised = generations[\"s\"].cpu()\n", + "new_noisy = generations[\"x\"].cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "img_idx = 0\n", + "top = 0\n", + "bottom = 256\n", + "left = 0\n", + "right = 256\n", + "\n", + "crop = (0, slice(top, bottom), slice(left, right))\n", + "\n", + "fig, ax = plt.subplots(1, 2, figsize=(8, 4))\n", + "ax[0].imshow(new_noisy[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[0].set_title(\"Generated noisy image\")\n", + "ax[1].imshow(new_denoised[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[1].set_title(\"Generated clean image\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Checkpoint 3\n", + "\n", + "In this notebook, we saw how the model you trained in the first notebook has learned to describe the data. We first added a new sample of noise to an existing noisy image. We then generated a clean image that looks like it could be from the training data but doesn't actually exist.
\n", + "You can now optionally return to section 3.1 to load a model that's been trained for much longer, otherwise, you've finished this module on COSDD.\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/03_COSDD/bonus-solution-generation.ipynb b/03_COSDD/bonus-solution-generation.ipynb new file mode 100755 index 0000000..cd6d66c --- /dev/null +++ b/03_COSDD/bonus-solution-generation.ipynb @@ -0,0 +1,347 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Bonus exercise. Generating new images with COSDD\n", + "\n", + "As mentioned in the training.ipynb notebook, COSDD is a deep generative model that captures the structures and characteristics of our data. In this notebook, we'll see how accurately it can represent our training data, in both the signal and the noise. We'll do this by using the model to generate entirely new images. These will be images that look like the ones in our training data but don't actually exist. This is the same as how models like DALL-E can generate entirely new images." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Set your python kernel to 05_image_restoration\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import tifffile\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "from COSDD import utils\n", + "from COSDD.models.hub import Hub\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assert torch.cuda.is_available()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.1. Load trained model and clean and noisy data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 3.1.\n", + "\n", + "Load the model trained in the first notebook by entering your `model_name`, or alternatively, uncomment line 4 to load the pretrained model.\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "model_name = ... # Insert a string here\n", + "checkpoint_path = os.path.join(\"checkpoints\", model_name)\n", + "\n", + "# checkpoint_path = \"checkpoints/mito-confocal-pretrained\"\n", + "\n", + "hub = Hub.load_from_checkpoint(os.path.join(checkpoint_path, \"final_model.ckpt\")).cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "model_name = \"mito-confocal\" # Insert a string here\n", + "checkpoint_path = os.path.join(\"checkpoints\", model_name)\n", + "\n", + "# checkpoint_path = \"checkpoints/mito-confocal-pretrained\"\n", + "\n", + "hub = Hub.load_from_checkpoint(os.path.join(checkpoint_path, \"final_model.ckpt\")).cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# load the data\n", + "lowsnr_path = \"./../data/mito-confocal-lowsnr.tif\"\n", + "low_snr = tifffile.imread(lowsnr_path)\n", + "low_snr = low_snr[:, np.newaxis]\n", + "low_snr = torch.from_numpy(low_snr)\n", + "low_snr = low_snr.to(torch.float32)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.2. Generating new noise for a real noisy image\n", + "\n", + "First, we'll pass a noisy image to the VAE and generate a random sample from the AR decoder. This will give us another noisy image with the same underlying clean signal but a different random sample of noise." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`inp_image` (torch.Tensor): The real noisy image we're going to add a different random sample of noise to.
\n", + "`denoised` (torch.Tensor): The denoised version of `inp_image`.
\n", + "`noisy` (torch.Tensor): The same underlying signal as `inp_image` but a different sample of noise." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "inp_image = low_snr[:1, :, :512, :512].cuda()\n", + "reconstructions = hub.reconstruct(inp_image)\n", + "denoised = reconstructions[\"s_hat\"].cpu()\n", + "noisy = reconstructions[\"x_hat\"].cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vmin = np.percentile(inp_image.cpu().numpy(), 0.1)\n", + "vmax = np.percentile(inp_image.cpu().numpy(), 99.9)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 3.2.\n", + "\n", + "Now we will look at the original noisy image and the generated noisy image. Adjust `top`, `bottom`, `left` and `right` to view different crops of the reconstructed image.\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "top = 0\n", + "bottom = 512\n", + "left = 0\n", + "right = 512\n", + "\n", + "crop = (0, slice(top, bottom), slice(left, right))\n", + "\n", + "fig, ax = plt.subplots(1, 3, figsize=(12, 4))\n", + "ax[0].imshow(inp_image[0][crop].cpu(), vmin=vmin, vmax=vmax)\n", + "ax[0].set_title(\"Original noisy image\")\n", + "ax[1].imshow(noisy[0][crop], vmin=vmin, vmax=vmax)\n", + "ax[1].set_title(\"Generated noisy image\")\n", + "ax[2].imshow(denoised[0][crop], vmin=vmin, vmax=vmax)\n", + "ax[2].set_title(\"Denoised image\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The spatial correlation of the generated noise can be compared to that of the real noise to get an idea of how accurate the model is. Since we have the denoised version of the generated image, we can get a noise sample by just subtracting it from the noisy versions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "real_noise = low_snr[8, 0, 800:, 800:]\n", + "generated_noise = noisy[0, 0] - denoised[0, 0]\n", + "\n", + "real_ac = utils.autocorrelation(real_noise, max_lag=25)\n", + "generated_ac = utils.autocorrelation(generated_noise, max_lag=25)\n", + "\n", + "fig, ax = plt.subplots(1, 2, figsize=(12, 5))\n", + "ac1 = ax[0].imshow(real_ac, cmap=\"seismic\", vmin=-1, vmax=1)\n", + "ax[0].set_title(\"Autocorrelation of real noise\")\n", + "ax[0].set_xlabel(\"Horizontal lag\")\n", + "ax[0].set_ylabel(\"Vertical lag\")\n", + "ac2 = ax[1].imshow(generated_ac, cmap=\"seismic\", vmin=-1, vmax=1)\n", + "ax[1].set_title(\"Autocorrelation of generated noise\")\n", + "ax[1].set_xlabel(\"Horizontal lag\")\n", + "ax[1].set_ylabel(\"Vertical lag\")\n", + "\n", + "fig.colorbar(ac2, fraction=0.045)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.3. Generating new images\n", + "\n", + "This time, we'll take a sample from the VAE's prior. This will be a latent variable containing information about a brand new signal. The signal decoder will take that latent variable and convert it into a clean image. The AR decoder will take the latent variable and create an image with the same clean image plus noise." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 3.3.\n", + "\n", + "Set the `n_imgs` variable below to decide how many images to generate. If you set it too high you'll get an out-of-memory error, but don't worry, just restart the kernel and run again with a lower value.\n", + "\n", + "Explore the images you generated in the second cell below. Look at the differences between them to see what aspects of the signal the model has learned to generate.\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "task", + "solution" + ] + }, + "outputs": [], + "source": [ + "n_imgs = ... # Insert an integer here\n", + "generations = hub.sample_prior(n_imgs=n_imgs)\n", + "new_denoised = generations[\"s\"].cpu()\n", + "new_noisy = generations[\"x\"].cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n_imgs = 5 # Insert an integer here\n", + "generations = hub.sample_prior(n_imgs=n_imgs)\n", + "new_denoised = generations[\"s\"].cpu()\n", + "new_noisy = generations[\"x\"].cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "img_idx = 0\n", + "top = 0\n", + "bottom = 256\n", + "left = 0\n", + "right = 256\n", + "\n", + "crop = (0, slice(top, bottom), slice(left, right))\n", + "\n", + "fig, ax = plt.subplots(1, 2, figsize=(8, 4))\n", + "ax[0].imshow(new_noisy[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[0].set_title(\"Generated noisy image\")\n", + "ax[1].imshow(new_denoised[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[1].set_title(\"Generated clean image\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Checkpoint 3\n", + "\n", + "In this notebook, we saw how the model you trained in the first notebook has learned to describe the data. We first added a new sample of noise to an existing noisy image. We then generated a clean image that looks like it could be from the training data but doesn't actually exist.
\n", + "You can now optionally return to section 3.1 to load a model that's been trained for much longer, otherwise, you've finished this module on COSDD.\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/03_COSDD/exercise.ipynb b/03_COSDD/exercise.ipynb new file mode 100644 index 0000000..b781074 --- /dev/null +++ b/03_COSDD/exercise.ipynb @@ -0,0 +1,1038 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Exercise 1. Training COSDD
\n", + "In this section, we will train a COSDD model to remove row correlated and signal-dependent imaging noise. \n", + "You will load noisy data and examine the noise for spatial correlation, then initialise a model and monitor its training.\n", + "Finally, you'll use the model to denoise the data.\n", + "\n", + "COSDD is a Ladder VAE with an autoregressive decoder, a type of deep generative model. Deep generative models are trained with the objective of capturing all the structures and characteristics present in a dataset, i.e., modelling the dataset. In our case the dataset will be a collection of noisy microscopy images. \n", + "\n", + "When COSDD is trained to model noisy images, it exploits differences between the structure of imaging noise and the structure of the clean signal to separate them, capturing each with different components of the model. Specifically, the noise will be captured by the autoregressive decoder and the signal will be captured by the VAE's latent variables. We can then feed an image into the model and sample a latent variable, which will describe the image's clean signal content. This latent variable is then fed through a second network, which was trained alongside the main VAE, to reveal an estimate of the denoised image." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Set your python kernel to 05_image_restoration\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import torch\n", + "import tifffile\n", + "import pytorch_lightning as pl\n", + "from pytorch_lightning.callbacks import EarlyStopping\n", + "from pytorch_lightning.loggers import TensorBoardLogger\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "\n", + "from COSDD import utils\n", + "from COSDD.models.lvae import LadderVAE\n", + "from COSDD.models.pixelcnn import PixelCNN\n", + "from COSDD.models.s_decoder import SDecoder\n", + "from COSDD.models.unet import UNet\n", + "from COSDD.models.hub import Hub\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assert torch.cuda.is_available()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.1. Load data\n", + "\n", + "In this example we will be using the Mito Confocal dataset, provided by:
\n", + "Hagen, G.M., Bendesky, J., Machado, R., Nguyen, T.A., Kumar, T. and Ventura, J., 2021. Fluorescence microscopy datasets for training deep neural networks. GigaScience, 10(5), p.giab032." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You will have tried denoising this in the final section of the N2V exercise and hopefully noticed the horizontal artifacts. In this exercise, we'll train a model that can handle spatially correlated noise and won't leave behind artifacts." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 1.1.\n", + "\n", + "The low signal-to-noise ratio data that we will be using in this exercise has been downloaded and stored as a tiff file at `./../data/mito-confocal-lowsnr.tif`. \n", + "\n", + "In the following cell, you'll load it and get it into a format suitable for training the denoiser.\n", + "\n", + "1. Use the function `tifffile.imread` to load the data as a numpy array.\n", + "2. Then use np.newaxis to add a channel axis. *Hint* The data is a stack of 2D images, so the channel axis should be the second dimension (dimension 1 if we start counting from zero).\n", + "3. Next, use `torch.from_numpy` to convert it into a pytorch tensor.\n", + "4. Lastly, convert the datatype to `torch.float32`.\n", + "\n", + "COSDD can handle 1-, 2- and 3-dimensional data, as long as it's loaded as a PyTorch tensor with a batch and channel dimension. For 1D data, it should have dimensions [Number of images, Channels, X], for 2D data: [Number of images, Channels, Y, X] and for 3D: [Number of images, Channels, Z, Y, X]. This applies even if the data has only one channel.\n", + "
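Try the task cell below first. If you get stuck, the four steps could look roughly like the following sketch (it mirrors how the same file is loaded elsewhere in this course and uses the modules imported at the top of the notebook; the path is the one given above):

```python
# 1. load the tiff as a numpy array, 2. add a channel axis,
# 3. convert to a pytorch tensor, 4. cast to float32
low_snr = tifffile.imread("./../data/mito-confocal-lowsnr.tif")
low_snr = low_snr[:, np.newaxis]
low_snr = torch.from_numpy(low_snr)
low_snr = low_snr.to(torch.float32)
```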
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "# load the data\n", + "low_snr = ...\n", + "low_snr = ...\n", + "low_snr = ...\n", + "low_snr = ...\n", + "\n", + "assert [*low_snr.size()] == [79, 1, 1024, 1024], \"Incorrect dimensions\"\n", + "assert low_snr.dtype == torch.float32, \"Incorrect data type\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.2. Examine spatial correlation of the noise" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "COSDD can be applied to noise that is correlated along rows or columns of pixels (or not spatially correlated at all). \n", + "However, it cannot be applied to noise that is correlated along rows *and* columns of pixels.\n", + "Noise2Void on the other hand, is designed for noise that is not spatially correlated at all.\n", + "\n", + "When we say that the noise is spatially correlated, we mean that knowing the value of the noise in one pixel tells us something about the noise in other (usually nearby) pixels.\n", + "Specifically, positive correlatation between two pixels tells us that if the intensity of the noise value in one pixel is high, the intensity of the noise value in the other pixel is likely to be high.\n", + "Similarly, if one is low, the other is likely to be low.\n", + "Negative correlation between pixels means that a low noise intensity in one pixel is more likely if the intensity in the other is high, and vice versa.\n", + "\n", + "To examine an image's spatial correlation, we can create an autocorrelation plot. \n", + "The plot will have two axes, horizontal lag and vertical lag, and tells us what the correlation between a pair of pixels separated by a given horizontal and vertical lag is.\n", + "For example, if the square at a horizontal lag of 3 and a vertical lag of 6 is red, it means that if we picked any pixel in the image, then counted 3 pixels to the right and 6 pixels down, this pair of pixels are positively correlated.\n", + "Correlation is symmetric, so the same is true if we counted left or up." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Question 1.1.\n", + "\n", + "Below are three autocorrelation plots. They show how the noise is spatially correlated in three different examples of noise.\n", + "Identify which noise examples could be removed by:
\n", + "(a) COSDD
\n", + "(b) Noise2Void
\n", + "(c) neither\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 1.2.\n", + "\n", + "Now we will create an autocorrelation plot of the data we loaded.\n", + "To do this, we need a sample of pure noise.\n", + "This can be a patch of `low_snr` with no signal.\n", + "Adjust the values for `image_index`, `top`, `bottom`, `left` and `right` to explore slices of the data and identify a suitable dark patch.\n", + "Once you have decided, set the `dark_patch` in the following cell and pass it as an argument to `utils.autocorrelation`, then plot the result.\n", + "\n", + "*Hint: The bigger the dark patch, the more accurate our estimate of the spatial autocorrelation will be.*\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vmin = np.percentile(low_snr, 1)\n", + "vmax = np.percentile(low_snr, 99)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "### Explore slices of the data here\n", + "image_index = 0\n", + "top = 0\n", + "bottom = 1024\n", + "left = 0\n", + "right = 1024\n", + "\n", + "crop = (image_index, 0, slice(top, bottom), slice(left, right))\n", + "\n", + "plt.figure(figsize=(10, 10))\n", + "plt.imshow(low_snr[crop], vmin=vmin, vmax=vmax)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "### Define the crop of the dark image patch here\n", + "dark_image_index = ...\n", + "dark_top = ...\n", + "dark_bottom = ...\n", + "dark_left = ...\n", + "dark_right = ...\n", + "\n", + "dark_crop = (dark_image_index, 0, slice(dark_top, dark_bottom), slice(dark_left, dark_right))\n", + "dark_patch = low_snr[dark_crop]\n", + "\n", + "noise_ac = utils.autocorrelation(dark_patch, max_lag=25)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot the autocorrelation\n", + "plt.figure()\n", + "plt.imshow(noise_ac, cmap=\"seismic\", vmin=-1, vmax=1)\n", + "plt.colorbar()\n", + "plt.title(\"Autocorrelation of the noise\")\n", + "plt.xlabel(\"Horizontal lag\")\n", + "plt.ylabel(\"Vertical lag\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this plot, all of the squares should be white, except for the top row. The autocorrelation of the square at (0, 0) will always be 1.0, as a pixel's value will always be perfectly correlated with itself. We define this type of noise as correlated along the x axis.\n", + "\n", + "To remove this type of noise, the autoregressive decoder of our VAE must have a receptive field spanning the x axis.\n", + "Note that if the data contained spatially *un*correlated noise, we could still remove it, as the decoder's receptive field will become redundant." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## Checkpoint 1\n", + "Now that we're familiar with our data, we'll train a COSDD model to denoise it.\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.3. Create training and validation dataloaders" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The data will be fed to the model by two dataloaders, `train_loader` and `val_loader`, for the training and validation set respectively.
\n", + "In this example, 90% of images will be used for training and the remaining 10% for validation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`real_batch_size` (int) Number of images passed through the network at a time.
\n", + "`n_grad_batches` (int) Number of batches to pass through the network before updating parameters.
\n", + "`crop_size` (tuple(int)): The size of randomly cropped patches. Should be less than the dimensions of your images.
\n", + "`train_split` (0 < float < 1): Fraction of images to be used in the training set, with the remainder used for the validation set.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "real_batch_size = 4\n", + "n_grad_batches = 4\n", + "print(f\"Effective batch size: {real_batch_size * n_grad_batches}\")\n", + "crop_size = (256, 256)\n", + "train_split = 0.9\n", + "\n", + "n_iters = np.prod(low_snr.shape[2:]) // np.prod(crop_size)\n", + "transform = utils.RandomCrop(crop_size)\n", + "\n", + "dataset = utils.TrainDataset(low_snr, n_iters=n_iters, transform=transform)\n", + "train_set, val_set = torch.utils.data.random_split(dataset, [train_split, 1-train_split])\n", + "train_loader = torch.utils.data.DataLoader(\n", + " train_set, batch_size=real_batch_size, shuffle=True, pin_memory=True, num_workers=7,\n", + ")\n", + "val_loader = torch.utils.data.DataLoader(\n", + " val_set, batch_size=real_batch_size, shuffle=False, pin_memory=True, num_workers=7,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.4. Create the model\n", + "\n", + "The model we will train to denoise consists of four modules, with forth being the optional Direct Denoiser which we can train if we want to speed up inference. Each module is listed below with an explanation of their hyperparameters." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "COSDD is a Variational Autoencoder (solid arrows) trained to model the distribution of noisy images $\\mathbf{x}$. \n", + "The autoregressive (AR) decoder models the noise component of the images, while the latent variable models only the clean signal component $\\mathbf{s}$.\n", + "In a second step (dashed arrows), the \\emph{signal decoder} is trained to map latent variables into image space, producing an estimate of the signal underlying $\\mathbf{x}$.\n", + "{\\bf b):}\n", + "To ensure that the decoder models only the imaging noise and the latent variables capture only the signal, the AR decoder's receptive field is modified.\n", + "In a full AR receptive field, each output pixel (red) is a function of all input pixels located above and to the left (blue). In our decoder's row-based AR receptive field, each output pixel is a function of input pixels located in the same row, which corresponds to the row-correlated structure of imaging noise." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`dimensions` (int): The dimensionality of the data. Can be 1, 2, or 3.\n", + "\n", + "`lvae` The ladder variational autoencoder that will output latent variables.
\n", + "* `s_code_channels` (int): Number of channels in the output latent variable.\n", + "* `n_layers` (int): Number of levels in the ladder VAE.\n", + "* `z_dims` (list(int)): List with the number of latent space dimensions at each level of the hierarchy. The list starts from the input/output level and works down.\n", + "* `downsampling` (list(int)): Binary list of whether to downsample at each level of the hierarchy. 1 for do and 0 for don't.\n", + "\n", + "`ar_decoder` The autoregressive decoder that will decode latent variables into a distribution over the input.
\n", + "* `kernel_size` (int): Length of 1D convolutional kernels.\n", + "* `noise_direction` (str): Axis along which noise is correlated: `\"x\"`, `\"y\"` or `\"z\"`. This needs to match the orientation of the noise structures we revealed in the autocorrelation plot in Task 1.2.\n", + "* `n_filters` (int): Number of feature channels.\n", + "* `n_gaussians` (int): Number of components in Gaussian mixture used to model data.\n", + "\n", + "`s_decoder` A decoder that will map the latent variables into image space, giving us a denoised image.
\n", + "* `n_filters` (int): The number of feature channels.
\n", + "\n", + "`direct_denoiser` The U-Net that can optionally be trained to predict the MMSE or MMAE of the denoised images. This will slow training slightly but massively speed up inference and is worthwhile if you have an inference dataset in the gigabytes. See [this paper](https://arxiv.org/abs/2310.18116). Enable or disable the direct denoiser by setting `use_direct_denoiser` to `True` or `False`.\n", + "* `n_filters` (int): Feature channels at each level of the UNet. Defaults to `s_code_channels`.\n", + "* `n_layers` (int): Number of levels in the UNet. Defaults to the number of levels in the `LadderVAE`.\n", + "* `downsampling` (list(int)): Binary list of whether to downsample at each level of the hierarchy. 1 for do and 0 for don't. Also defaults to match the `LadderVAE`.\n", + "* `loss_fn` (str): Whether to use the `\"L1\"` or `\"L2\"` loss function to predict either the pixel-wise median or the mean of the denoised images, respectively.\n", + "\n", + "`hub` The hub that will unify and train the above modules.\n", + "* `n_grad_batches` (int): Number of batches to accumulate gradients for before updating weights of all models. If the real batch or random crop size has been reduced to lower memory consumption, increase this value for the effective batch size to stay the same.\n", + "* `checkpointed` (bool): Whether to use activation checkpointing during training. This reduces memory consumption but increases training time. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 1.3.\n", + "\n", + "Most hyperparameters have been set to recommended values for a small sized model. The three that have been left blank are `dimensions`, `noise_direction` under the `ar_decoder`, and `use_direct_denoiser`. Use the above description of what each hyperparameter means to determine the best value for each of these.\n", + "\n", + "*Hint: In this notebook we're using 2D data*
\n", + "*Hint: enabling the Direct Denoiser will give us additional results to look at in the next notebook.*\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "dimensions = ... ### Insert a value here\n", + "s_code_channels = 32\n", + "\n", + "n_layers = 6\n", + "z_dims = [s_code_channels // 2] * n_layers\n", + "downsampling = [1] * n_layers\n", + "lvae = LadderVAE(\n", + " colour_channels=low_snr.shape[1],\n", + " img_size=crop_size,\n", + " s_code_channels=s_code_channels,\n", + " n_filters=s_code_channels,\n", + " z_dims=z_dims,\n", + " downsampling=downsampling,\n", + " dimensions=dimensions,\n", + ")\n", + "\n", + "ar_decoder = PixelCNN(\n", + " colour_channels=low_snr.shape[1],\n", + " s_code_channels=s_code_channels,\n", + " kernel_size=5,\n", + " noise_direction=... ### Insert a value here\n", + " n_filters=32,\n", + " n_layers=4,\n", + " n_gaussians=4,\n", + " dimensions=dimensions,\n", + ")\n", + "\n", + "s_decoder = SDecoder(\n", + " colour_channels=low_snr.shape[1],\n", + " s_code_channels=s_code_channels,\n", + " n_filters=s_code_channels,\n", + " dimensions=dimensions,\n", + ")\n", + "\n", + "use_direct_denoiser = ... ### Insert a value here\n", + "if use_direct_denoiser:\n", + " direct_denoiser = UNet(\n", + " colour_channels=low_snr.shape[1],\n", + " n_filters=s_code_channels,\n", + " n_layers=n_layers,\n", + " downsampling=downsampling,\n", + " loss_fn=\"L2\",\n", + " dimensions=dimensions,\n", + " )\n", + "else:\n", + " direct_denoiser = None\n", + "\n", + "hub = Hub(\n", + " vae=lvae,\n", + " ar_decoder=ar_decoder,\n", + " s_decoder=s_decoder,\n", + " direct_denoiser=direct_denoiser,\n", + " data_mean=low_snr.mean(),\n", + " data_std=low_snr.std(),\n", + " n_grad_batches=n_grad_batches,\n", + " checkpointed=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.5. Train the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 1.4.\n", + "\n", + "Open Tensorboard (check Task 3 in 01_CARE) to monitor training.\n", + "This model is unlike the previous two because it has more than one loss curve.\n", + "The cell below describes how to interpret each one.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Tensorboard metrics\n", + "\n", + "In the SCALARS tab, there will be 4 metrics to track (5 if direct denoiser is enabled). These are:
\n", + "1. `kl_loss` The Kullback-Leibler divergence between the VAE's approximate posterior and its prior. This can be thought of as a measure of how much information about the input image is going into the VAE's latent variables. We want information about the input's underlying clean signal to go into the latent variables, so this metric shouldn't go all the way to zero. Instead, it can typically go either up or down during training before plateauing.
\n", + "2. `reconstruction_loss` The negative log-likelihood of the AR decoder's predicted distribution given the input data. This is how accurately the AR decoder is able to predict the input. This value can go below zero and should decrease throughout training before plateauing.
\n", + "3. `elbo` The Evidence Lower Bound, which is the total loss of the main VAE. This is the sum of the kl and reconstruction loss and should decrease throughout training before plateauing.
\n", + "4. `sd_loss` The mean squared error between the noisy image and the image predicted by the signal decoder. This metric should steadily decrease towards zero without ever reaching it. Sometimes the loss will not go down for the first few epochs because its input (produced by the VAE) is rapidly changing. This is ok and the loss should start to decrease when the VAE stabilises.
\n", + "5. `dd_loss` The mean squared error between the output of the direct denoiser and the clean images predicted by the signal decoder. This will only be present if `use_direct_denoiser` is set to `True`. The metric should steadily decrease towards zero without ever reaching it, but may be unstable at the start of training as its targets (produced by the signal decoder) are rapidly changing.\n", + "\n", + "There will also be an IMAGES tab. This shows noisy input images from the validation set and some outputs. These will be two randomly sampled denoised images (sample 1 and sample 2), the average of ten denoised images (mmse) and if the direct denoiser is enabled, its output (direct estimate).\n", + "\n", + "If noise has not been fully removed from the output images, try increasing `n_gaussians` argument of the AR decoder. This will give it more flexibility to model complex noise characteristics. However, setting the value too high can lead to unstable training. Typically, values from 3 to 5 work best.\n", + "\n", + "Note that the trainer is set to train for only 10 minutes in this example. Remove the line with `max_time` to train fully." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 1.5.\n", + "\n", + "Now the model is ready to start training. Give the model a sensible name by setting `model_name` to a string, then run the following cells.\n", + "\n", + "The `max_time` parameter in the cell below means we'll only train the model for 10 minutes, just to get idea of what to expect. In the future, to remove the time restriction, the `max_time` parameter can be set to `None`.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`model_name` (str): Should be set to something appropriate so that the trained parameters can be used later for inference.
\n", + "`max_epochs` (int): The number of training epochs.
\n", + "`patience` (int): If the validation loss has plateaued for this many epochs, training will stop." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "model_name = ... ### Insert a value here\n", + "checkpoint_path = os.path.join(\"checkpoints\", model_name)\n", + "logger = TensorBoardLogger(checkpoint_path)\n", + "\n", + "max_epochs = 1000\n", + "max_time = \"00:00:10:00\"\n", + "patience = 100\n", + "\n", + "trainer = pl.Trainer(\n", + " logger=logger,\n", + " accelerator=\"gpu\",\n", + " devices=1,\n", + " max_epochs=max_epochs,\n", + " max_time=max_time, # Remove this time limit to train the model fully\n", + " log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n", + " callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.fit(hub, train_loader, val_loader)\n", + "trainer.save_checkpoint(os.path.join(checkpoint_path, \"final_model.ckpt\"))\n", + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## Checkpoint 2\n", + "We've now trained a COSDD model to denoise our data. Continue to the next part to use it to get some results.\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Exercise 2. Inference with COSDD" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.1. Load test data\n", + "The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 10 to speed up inference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lowsnr_path = \"./../data/mito-confocal-lowsnr.tif\"\n", + "n_test_images = 5\n", + "# load the data\n", + "test_set = tifffile.imread(lowsnr_path)\n", + "test_set = test_set[:n_test_images, np.newaxis]\n", + "test_set = torch.from_numpy(test_set)\n", + "test_set = test_set.to(torch.float32)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As with training, data should be a `torch.Tensor` with dimensions: [Number of images, Channels, Z | Y | X] with data type float32." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Part 2. Create prediction dataloader" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`predict_batch_size` (int): Number of denoised images to produce at a time." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predict_batch_size = 1\n", + "\n", + "predict_set = utils.PredictDataset(test_set)\n", + "predict_loader = torch.utils.data.DataLoader(\n", + " predict_set,\n", + " batch_size=predict_batch_size,\n", + " shuffle=False,\n", + " pin_memory=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.3. Load trained model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 2.1.\n", + "\n", + "Our model was only trained for 10 minutes. This is long enough to get some denoising results, but a model trained for longer would do better. In the cell below, load the trained model by recalling the value you gave for `model_name`. Then procede through the notebook to look at how well it performs. \n", + "\n", + "Once you reach the end of the notebook, return to this cell to load a model that has been trained for 3.5 hours by uncommenting line 4, then run the notebook again to see how much difference the extra training time makes. \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = ... ### Insert a string here\n", + "checkpoint_path = os.path.join(\"checkpoints\", model_name)\n", + "\n", + "# checkpoint_path = \"checkpoints/mito-confocal-pretrained\" ### Once you reach the bottom of the notebook, return here and uncomment this line to see the pretrained model\n", + "\n", + "hub = Hub.load_from_checkpoint(os.path.join(checkpoint_path, \"final_model.ckpt\"))\n", + "\n", + "predictor = pl.Trainer(\n", + " accelerator=\"gpu\",\n", + " devices=1,\n", + " enable_progress_bar=False,\n", + " enable_checkpointing=False,\n", + " logger=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"mito-confocal\" ### Insert a string here\n", + "checkpoint_path = os.path.join(\"checkpoints\", model_name)\n", + "\n", + "# checkpoint_path = \"checkpoints/mito-confocal-pretrained\" ### Once you reach the bottom of the notebook, return here and uncomment this line to see the pretrained model\n", + "\n", + "hub = Hub.load_from_checkpoint(os.path.join(checkpoint_path, \"final_model.ckpt\"))\n", + "\n", + "predictor = pl.Trainer(\n", + " accelerator=\"gpu\",\n", + " devices=1,\n", + " enable_progress_bar=False,\n", + " enable_checkpointing=False,\n", + " logger=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.4. Denoise\n", + "In this section, we will look at how COSDD does inference.
\n", + "\n", + "The model denoises images randomly, giving us a different output each time. First, we will compare seven randomly sampled denoised images for the same noisy image. Then, we will produce a single consensus estimate by averaging 100 randomly sampled denoised images. Finally, if the direct denoiser was trained in the previous step, we will see how it can be used to estimate this average in a single pass." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.4.1 Random sampling \n", + "First, we will denoise each image seven times and look at the difference between each estimate. The output of the model is stored in the `samples` variable. This has dimensions [Number of images, Sample index, Channels, Z | Y | X] where different denoised samples for the same image are stored along sample index." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "use_direct_denoiser = False\n", + "n_samples = 7\n", + "\n", + "hub.direct_pred = use_direct_denoiser\n", + "samples = []\n", + "for _ in tqdm(range(n_samples)):\n", + " out = predictor.predict(hub, predict_loader)\n", + " out = torch.cat(out, dim=0)\n", + " samples.append(out)\n", + "\n", + "samples = torch.stack(samples, dim=1).half()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 2.2.\n", + "\n", + "Here, we'll look at the original noisy image and the seven denoised estimates. Change the value for `img_idx` to look at different images and change values for `top`, `bottom`, `left` and `right` to adjust the crop. Use this section to really explore the results. Compare high intensity reigons to low intensity reigons, zoom in and out and spot the differences between the different samples. \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vmin = np.percentile(test_set.numpy(), 1)\n", + "vmax = np.percentile(test_set.numpy(), 99)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "img_idx = 0\n", + "top = 0\n", + "bottom = 1024\n", + "left = 0\n", + "right = 1024\n", + "\n", + "crop = (0, slice(top, bottom), slice(left, right))\n", + "\n", + "fig, ax = plt.subplots(2, 4, figsize=(16, 8))\n", + "ax[0, 0].imshow(test_set[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[0, 0].set_title(\"Input\")\n", + "for i in range(n_samples):\n", + " ax[(i + 1) // 4, (i + 1) % 4].imshow(\n", + " samples[img_idx][i][crop], vmin=vmin, vmax=vmax\n", + " )\n", + " ax[(i + 1) // 4, (i + 1) % 4].set_title(f\"Sample {i+1}\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The six sampled denoised images have subtle differences that express the uncertainty involved in this denoising problem." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.4.2 MMSE estimate\n", + "\n", + "In the next cell, we sample many denoised images and average them for the minimum mean square estimate (MMSE). The averaged images will be stored in the `MMSEs` variable, which has the same dimensions as `low_snr`. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 2.3.\n", + "Set `n_samples` to 100 to average 100 images, or a different value to average a different number. Then visually inspeect the results. Examine how the MMSE result differs from the random sample.\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "use_direct_denoiser = False\n", + "n_samples = ... ### Insert an integer here\n", + "\n", + "hub.direct_pred = use_direct_denoiser\n", + "\n", + "samples = []\n", + "for _ in tqdm(range(n_samples)):\n", + " out = predictor.predict(hub, predict_loader)\n", + " out = torch.cat(out, dim=0)\n", + " samples.append(out)\n", + "\n", + "samples = torch.stack(samples, dim=1).half()\n", + "MMSEs = torch.mean(samples, dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "use_direct_denoiser = False\n", + "n_samples = 100 ### Insert an integer here\n", + "\n", + "hub.direct_pred = use_direct_denoiser\n", + "\n", + "samples = []\n", + "for _ in tqdm(range(n_samples)):\n", + " out = predictor.predict(hub, predict_loader)\n", + " out = torch.cat(out, dim=0)\n", + " samples.append(out)\n", + "\n", + "samples = torch.stack(samples, dim=1).half()\n", + "MMSEs = torch.mean(samples, dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "img_idx = 0\n", + "top = 0\n", + "bottom = 1024\n", + "left = 0\n", + "right = 1024\n", + "\n", + "crop = (0, slice(top, bottom), slice(left, right))\n", + "\n", + "fig, ax = plt.subplots(1, 3, figsize=(12, 4))\n", + "ax[0].imshow(test_set[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[0].set_title(\"Input\")\n", + "ax[1].imshow(samples[img_idx][0][crop], vmin=vmin, vmax=vmax)\n", + "ax[1].set_title(\"Sample\")\n", + "ax[2].imshow(MMSEs[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[2].set_title(\"MMSE\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The MMSE will usually be closer to the reference than an individual sample and would score a higher PSNR, although it will also be blurrier." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.4.3 Direct denoising\n", + "Sampling 100 images and averaging them is a very time consuming. If the direct denoiser was trained in a previous step, it can be used to directly output what the average denoised image would be for a given noisy image." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 2.4.\n", + "\n", + "Did you enable the direct denoiser in the previous notebook? If so, set `use_direct_denoiser` to `True` to use the Direct Denoiser for inference. If not, go back to section 2.3 to load the pretrained model and return here. \n", + "\n", + "Notice how much quicker the direct denoiser is than generating the MMSE results. Visually inspect and explore the results in the same way as before, notice how similar the direct estimate and MMSE estimate are.\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "use_direct_denoiser = ... ### Insert a boolean here\n", + "hub.direct_pred = use_direct_denoiser\n", + "\n", + "direct = predictor.predict(hub, predict_loader)\n", + "direct = torch.cat(direct, dim=0).half()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "use_direct_denoiser = True ### Insert a boolean here\n", + "hub.direct_pred = use_direct_denoiser\n", + "\n", + "direct = predictor.predict(hub, predict_loader)\n", + "direct = torch.cat(direct, dim=0).half()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "img_idx = 0\n", + "top = 0\n", + "bottom = 1024\n", + "left = 0\n", + "right = 1024\n", + "\n", + "crop = (0, slice(top, bottom), slice(left, right))\n", + "\n", + "fig, ax = plt.subplots(1, 3, figsize=(12, 4))\n", + "ax[0].imshow(test_set[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[0].set_title(\"Input\")\n", + "ax[1].imshow(direct[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[1].set_title(\"Direct\")\n", + "ax[2].imshow(MMSEs[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[2].set_title(\"MMSE\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.5. Incorrect receptive field\n", + "\n", + "We've now trained a model and used it to remove structured noise from our data. Before moving onto the next notebook, we'll look at what happens when a COSDD model is trained without considering the noise structures present. \n", + "\n", + "COSDD is able to separate imaging noise from clean signal because its autoregressive decoder has a receptive field that spans pixels containing correlated noise, i.e., the row or column of pixels. If its receptive field did not contain pixels with correlated noise, it would not be able to model them and they would be captured by the VAE's latent variables. To demonstrate this, the image below shows a Direct and MMSE estimate of a denoised image where the autoregressive decoder's receptive field was incorrectly set to vertical, leaving it unable to model horizontal noise." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## Checkpoint 3\n", + "\n", + "We've completed the process of training and applying a COSDD model for denoising, but there's still more it can do. Optionally continue to the bonus notebook, bonus-exercise-generation.ipynb, to see how the model of the data can be used to generate new clean and noisy images.\n", + "\n", + "Otherwise, continue to 04_DenoiSplit.\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "05_image_restoration", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/03_COSDD/resources/ac-question.png b/03_COSDD/resources/ac-question.png new file mode 100644 index 0000000..f932d49 Binary files /dev/null and b/03_COSDD/resources/ac-question.png differ diff --git a/03_COSDD/resources/explainer.png b/03_COSDD/resources/explainer.png new file mode 100644 index 0000000..12d7c9f Binary files /dev/null and b/03_COSDD/resources/explainer.png differ diff --git a/03_COSDD/resources/matrix.png b/03_COSDD/resources/matrix.png new file mode 100755 index 0000000..5bdbb22 Binary files /dev/null and b/03_COSDD/resources/matrix.png differ diff --git a/03_COSDD/resources/penicillium_ynm.png b/03_COSDD/resources/penicillium_ynm.png new file mode 100755 index 0000000..6d998c8 Binary files /dev/null and b/03_COSDD/resources/penicillium_ynm.png differ diff --git a/03_COSDD/solution.ipynb b/03_COSDD/solution.ipynb new file mode 100755 index 0000000..7aae351 --- /dev/null +++ b/03_COSDD/solution.ipynb @@ -0,0 +1,1192 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Exercise 1. Training COSDD
\n", + "In this section, we will train a COSDD model to remove row correlated and signal-dependent imaging noise. \n", + "You will load noisy data and examine the noise for spatial correlation, then initialise a model and monitor its training.\n", + "Finally, you'll use the model to denoise the data.\n", + "\n", + "COSDD is a Ladder VAE with an autoregressive decoder, a type of deep generative model. Deep generative models are trained with the objective of capturing all the structures and characteristics present in a dataset, i.e., modelling the dataset. In our case the dataset will be a collection of noisy microscopy images. \n", + "\n", + "When COSDD is trained to model noisy images, it exploits differences between the structure of imaging noise and the structure of the clean signal to separate them, capturing each with different components of the model. Specifically, the noise will be captured by the autoregressive decoder and the signal will be captured by the VAE's latent variables. We can then feed an image into the model and sample a latent variable, which will describe the image's clean signal content. This latent variable is then fed through a second network, which was trained alongside the main VAE, to reveal an estimate of the denoised image." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Set your python kernel to 05_image_restoration\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import torch\n", + "import tifffile\n", + "import pytorch_lightning as pl\n", + "from pytorch_lightning.callbacks import EarlyStopping\n", + "from pytorch_lightning.loggers import TensorBoardLogger\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "\n", + "from COSDD import utils\n", + "from COSDD.models.lvae import LadderVAE\n", + "from COSDD.models.pixelcnn import PixelCNN\n", + "from COSDD.models.s_decoder import SDecoder\n", + "from COSDD.models.unet import UNet\n", + "from COSDD.models.hub import Hub\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assert torch.cuda.is_available()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.1. Load data\n", + "\n", + "In this example we will be using the Mito Confocal dataset, provided by:
\n", + "Hagen, G.M., Bendesky, J., Machado, R., Nguyen, T.A., Kumar, T. and Ventura, J., 2021. Fluorescence microscopy datasets for training deep neural networks. GigaScience, 10(5), p.giab032." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You will have tried denoising this in the final section of the N2V exercise and hopefully noticed the horizontal artifacts. In this exercise, we'll train a model that can handle spatially correlated noise and won't leave behind artifacts." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 1.1.\n", + "\n", + "The low signal-to-noise ratio data that we will be using in this exercise has been downloaded and stored as a tiff file at `./../data/mito-confocal-lowsnr.tif`. \n", + "\n", + "In the following cell, you'll load it and get it into a format suitable for training the denoiser.\n", + "\n", + "1. Use the function `tifffile.imread` to load the data as a numpy array.\n", + "2. Then use np.newaxis to add a channel axis. *Hint* The data is a stack of 2D images, so the channel axis should be the second dimension (dimension 1 if we start counting from zero).\n", + "3. Next, use `torch.from_numpy` to convert it into a pytorch tensor.\n", + "4. Lastly, convert the datatype to `torch.float32`.\n", + "\n", + "COSDD can handle 1-, 2- and 3-dimensional data, as long as it's loaded as a PyTorch tensor with a batch and channel dimension. For 1D data, it should have dimensions [Number of images, Channels, X], for 2D data: [Number of images, Channels, Y, X] and for 3D: [Number of images, Channels, Z, Y, X]. This applies even if the data has only one channel.\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "# load the data\n", + "low_snr = ...\n", + "low_snr = ...\n", + "low_snr = ...\n", + "low_snr = ...\n", + "\n", + "assert [*low_snr.size()] == [79, 1, 1024, 1024], \"Incorrect dimensions\"\n", + "assert low_snr.dtype == torch.float32, \"Incorrect data type\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "# load the data\n", + "low_snr = tifffile.imread(\"./../data/mito-confocal-lowsnr.tif\")\n", + "low_snr = low_snr[:, np.newaxis]\n", + "low_snr = torch.from_numpy(low_snr)\n", + "low_snr = low_snr.to(torch.float32)\n", + "\n", + "assert [*low_snr.size()] == [79, 1, 1024, 1024], \"Incorrect dimensions\"\n", + "assert low_snr.dtype == torch.float32, \"Incorrect data type\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.2. Examine spatial correlation of the noise" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "COSDD can be applied to noise that is correlated along rows or columns of pixels (or not spatially correlated at all). \n", + "However, it cannot be applied to noise that is correlated along rows *and* columns of pixels.\n", + "Noise2Void on the other hand, is designed for noise that is not spatially correlated at all.\n", + "\n", + "When we say that the noise is spatially correlated, we mean that knowing the value of the noise in one pixel tells us something about the noise in other (usually nearby) pixels.\n", + "Specifically, positive correlatation between two pixels tells us that if the intensity of the noise value in one pixel is high, the intensity of the noise value in the other pixel is likely to be high.\n", + "Similarly, if one is low, the other is likely to be low.\n", + "Negative correlation between pixels means that a low noise intensity in one pixel is more likely if the intensity in the other is high, and vice versa.\n", + "\n", + "To examine an image's spatial correlation, we can create an autocorrelation plot. \n", + "The plot will have two axes, horizontal lag and vertical lag, and tells us what the correlation between a pair of pixels separated by a given horizontal and vertical lag is.\n", + "For example, if the square at a horizontal lag of 3 and a vertical lag of 6 is red, it means that if we picked any pixel in the image, then counted 3 pixels to the right and 6 pixels down, this pair of pixels are positively correlated.\n", + "Correlation is symmetric, so the same is true if we counted left or up." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Question 1.1.\n", + "\n", + "Below are three autocorrelation plots. The show how the noise is spatially correlated in three different examples of noise.\n", + "Identify which noise examples could be removed by:
\n", + "(a) COSDD
\n", + "(b) Noise2Void
\n", + "(c) neither\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [ + "solution" + ] + }, + "source": [ + "1: COSDD and Noise2Void
\n", + "2: COSDD
\n", + "3: Neither" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 1.2.\n", + "\n", + "Now we will create an autocorrelation plot of the data we loaded.\n", + "To do this, we need a sample of pure noise.\n", + "This can be a patch of `low_snr` with no signal. \n", + "Adjust the values for `image_idx`, `top`, `bottom`, `left` and `right` to explore slices of the data and identify a suitable dark patch. \n", + "When decided, set the `dark_patch` in the following cell and pass it as an argument to `utils.autocorrelation`, then plot the result. \n", + "\n", + "*Hint: The bigger the dark patch, the more accurate our estimate of the spatial autocorrelation will be.*\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vmin = np.percentile(low_snr, 1)\n", + "vmax = np.percentile(low_snr, 99)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "### Explore slices of the data here\n", + "image_index = 0\n", + "top = 0\n", + "bottom = 1024\n", + "left = 0\n", + "right = 1024\n", + "\n", + "crop = (image_index, 0, slice(top, bottom), slice(left, right))\n", + "\n", + "plt.figure(figsize=(10, 10))\n", + "plt.imshow(low_snr[crop], vmin=vmin, vmax=vmax)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "### Define the crop of the dark image patch here\n", + "dark_image_index = ...\n", + "dark_top = ...\n", + "dark_bottom = ...\n", + "dark_left = ...\n", + "dark_right = ...\n", + "\n", + "dark_crop = (dark_image_index, 0, slice(dark_top, dark_bottom), slice(dark_left, dark_right))\n", + "dark_patch = low_snr[dark_crop]\n", + "\n", + "noise_ac = utils.autocorrelation(dark_patch, max_lag=25)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "### Define the crop of the dark image patch here\n", + "dark_img_idx = 8\n", + "dark_top = 800\n", + "dark_bottom = 1024\n", + "dark_left = 800\n", + "dark_right = 1024\n", + "\n", + "dark_crop = (dark_img_idx, 0, slice(dark_top, dark_bottom), slice(dark_left, dark_right))\n", + "dark_patch = low_snr[dark_crop]\n", + "\n", + "noise_ac = utils.autocorrelation(dark_patch, max_lag=25)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot the autocorrelation\n", + "plt.figure()\n", + "plt.imshow(noise_ac, cmap=\"seismic\", vmin=-1, vmax=1)\n", + "plt.colorbar()\n", + "plt.title(\"Autocorrelation of the noise\")\n", + "plt.xlabel(\"Horizontal lag\")\n", + "plt.ylabel(\"Vertical lag\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this plot, all of the squares should be white, except for the top row. The autocorrelation of the square at (0, 0) will always be 1.0, as a pixel's value will always be perfectly correlated with itself. We define this type of noise as correlated along the x axis.\n", + "\n", + "To remove this type of noise, the autoregressive decoder of our VAE must have a receptive field spanning the x axis.\n", + "Note that if the data contained spatially *un*correlated noise, we could still remove it, as the decoder's receptive field will become redundant." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## Checkpoint 1\n", + "Now that we're familiar with our data, we'll train a COSDD model to denoise it.\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.3. Create training and validation dataloaders" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The data will be fed to the model by two dataloaders, `train_loader` and `val_loader`, for the training and validation set respectively.
\n", + "In this example, 90% of images will be used for training and the remaining 10% for validation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`real_batch_size` (int) Number of images passed through the network at a time.
\n", + "`n_grad_batches` (int) Number of batches to pass through the network before updating parameters.
\n", + "`crop_size` (tuple(int)): The size of randomly cropped patches. Should be less than the dimensions of your images.
\n", + "`train_split` (0 < float < 1): Fraction of images to be used in the training set, with the remainder used for the validation set.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "real_batch_size = 4\n", + "n_grad_batches = 4\n", + "print(f\"Effective batch size: {real_batch_size * n_grad_batches}\")\n", + "crop_size = (256, 256)\n", + "train_split = 0.9\n", + "\n", + "n_iters = np.prod(low_snr.shape[2:]) // np.prod(crop_size)\n", + "transform = utils.RandomCrop(crop_size)\n", + "\n", + "dataset = utils.TrainDataset(low_snr, n_iters=n_iters, transform=transform)\n", + "train_set, val_set = torch.utils.data.random_split(dataset, [train_split, 1-train_split])\n", + "train_loader = torch.utils.data.DataLoader(\n", + " train_set, batch_size=real_batch_size, shuffle=True, pin_memory=True, num_workers=7,\n", + ")\n", + "val_loader = torch.utils.data.DataLoader(\n", + " val_set, batch_size=real_batch_size, shuffle=False, pin_memory=True, num_workers=7,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.4. Create the model\n", + "\n", + "The model we will train to denoise consists of four modules, with forth being the optional Direct Denoiser which we can train if we want to speed up inference. Each module is listed below with an explanation of their hyperparameters." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "COSDD is a Variational Autoencoder (solid arrows) trained to model the distribution of noisy images $\\mathbf{x}$. \n", + "The autoregressive (AR) decoder models the noise component of the images, while the latent variable models only the clean signal component $\\mathbf{s}$.\n", + "In a second step (dashed arrows), the \\emph{signal decoder} is trained to map latent variables into image space, producing an estimate of the signal underlying $\\mathbf{x}$.\n", + "{\\bf b):}\n", + "To ensure that the decoder models only the imaging noise and the latent variables capture only the signal, the AR decoder's receptive field is modified.\n", + "In a full AR receptive field, each output pixel (red) is a function of all input pixels located above and to the left (blue). In our decoder's row-based AR receptive field, each output pixel is a function of input pixels located in the same row, which corresponds to the row-correlated structure of imaging noise." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`dimensions` (int): The dimensionality of the data. Can be 1, 2, or 3.\n", + "\n", + "`lvae` The ladder variational autoencoder that will output latent variables.
\n", + "* `s_code_channels` (int): Number of channels in outputted latent variable.\n", + "* `n_layers` (int): Number of levels in the ladder vae.\n", + "* `z_dims` (list(int)): List with the numer of latent space dimensions at each level of the hierarchy. List starts from the input/output level and works down.\n", + "* `downsampling` (list(int)): Binary list of whether to downsample at each level of the hierarchy. 1 for do and 0 for don't.\n", + "\n", + "`ar_decoder` The autoregressive decoder that will decode latent variables into a distribution over the input.
\n", + "* `kernel_size` (int): Length of 1D convolutional kernels.\n", + "* `noise_direction` (str): Axis along which noise is correlated: `\"x\"`, `\"y\"` or `\"z\"`. This needs to match the orientation of the noise structures we revealed in the autocorrelation plot in Task 1.2.\n", + "* `n_filters` (int): Number of feature channels.\n", + "* `n_gaussians` (int): Number of components in Gaussian mixture used to model data.\n", + "\n", + "`s_decoder` A decoder that will map the latent variables into image space, giving us a denoised image.
\n", + "* `n_filters` (int): The number of feature channels.
\n", + "\n", + "`direct_denoiser` The U-Net that can optionally be trained to predict the MMSE or MMAE of the denoised images. This will slow training slightly but massively speed up inference and is worthwile if you have an inference dataset in the gigabytes. See [this paper](https://arxiv.org/abs/2310.18116). Enable or disable the direct denoiser by setting `use_direct_denoiser` to `True` or `False`.\n", + "* `n_filters` (int): Feature channels at each level of UNet. Defaults to `s_code_channel`.\n", + "* `n_layers` (int): Number of levels in the UNet. Defaults to the number of levels in the `LadderVAE`.\n", + "* `downsampling` (list(int)): Binary list of whether to downsample at each level of the hierarchy. 1 for do and 0 for don't. Also defaults to match the `LadderVAE`.\n", + "* `loss_fn` (str): Whether to use `\"L1\"` or `\"L2\"` loss function to predict either the mean or pixel-wise median of denoised images respectively.\n", + "\n", + "`hub` The hub that will unify and train the above modules.\n", + "* `n_grad_batches` (int): Number of batches to accumulate gradients for before updating weights of all models. If the real batch or random crop size has been reduced to lower memory consumption, increase this value for the effective batch size to stay the same.\n", + "* `checkpointed` (bool): Whether to use activation checkpointing during training. This reduces memory consumption but increases training time. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 1.3.\n", + "\n", + "Most hyperparameters have been set to recommended values for a small sized model. The three that have been left blank are `dimensions`, `noise_direction` under the `ar_decoder`, and `use_direct_denoiser`. Use the above description of what each hyperparameter means to determine the best value for each of these.\n", + "\n", + "*Hint: In this notebook we're using 2D data*
\n", + "*Hint: enabling the Direct Denoiser will give us additional results to look at in the next notebook.*\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "dimensions = ... ### Insert a value here\n", + "s_code_channels = 32\n", + "\n", + "n_layers = 6\n", + "z_dims = [s_code_channels // 2] * n_layers\n", + "downsampling = [1] * n_layers\n", + "lvae = LadderVAE(\n", + " colour_channels=low_snr.shape[1],\n", + " img_size=crop_size,\n", + " s_code_channels=s_code_channels,\n", + " n_filters=s_code_channels,\n", + " z_dims=z_dims,\n", + " downsampling=downsampling,\n", + " dimensions=dimensions,\n", + ")\n", + "\n", + "ar_decoder = PixelCNN(\n", + " colour_channels=low_snr.shape[1],\n", + " s_code_channels=s_code_channels,\n", + " kernel_size=5,\n", + " noise_direction=... ### Insert a value here\n", + " n_filters=32,\n", + " n_layers=4,\n", + " n_gaussians=4,\n", + " dimensions=dimensions,\n", + ")\n", + "\n", + "s_decoder = SDecoder(\n", + " colour_channels=low_snr.shape[1],\n", + " s_code_channels=s_code_channels,\n", + " n_filters=s_code_channels,\n", + " dimensions=dimensions,\n", + ")\n", + "\n", + "use_direct_denoiser = ... ### Insert a value here\n", + "if use_direct_denoiser:\n", + " direct_denoiser = UNet(\n", + " colour_channels=low_snr.shape[1],\n", + " n_filters=s_code_channels,\n", + " n_layers=n_layers,\n", + " downsampling=downsampling,\n", + " loss_fn=\"L2\",\n", + " dimensions=dimensions,\n", + " )\n", + "else:\n", + " direct_denoiser = None\n", + "\n", + "hub = Hub(\n", + " vae=lvae,\n", + " ar_decoder=ar_decoder,\n", + " s_decoder=s_decoder,\n", + " direct_denoiser=direct_denoiser,\n", + " data_mean=low_snr.mean(),\n", + " data_std=low_snr.std(),\n", + " n_grad_batches=n_grad_batches,\n", + " checkpointed=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "dimensions = 2 ### Insert a value here\n", + "s_code_channels = 32\n", + "\n", + "n_layers = 6\n", + "z_dims = [s_code_channels // 2] * n_layers\n", + "downsampling = [1] * n_layers\n", + "lvae = LadderVAE(\n", + " colour_channels=low_snr.shape[1],\n", + " img_size=crop_size,\n", + " s_code_channels=s_code_channels,\n", + " n_filters=s_code_channels,\n", + " z_dims=z_dims,\n", + " downsampling=downsampling,\n", + " dimensions=dimensions,\n", + ")\n", + "\n", + "ar_decoder = PixelCNN(\n", + " colour_channels=low_snr.shape[1],\n", + " s_code_channels=s_code_channels,\n", + " kernel_size=5,\n", + " noise_direction=\"x\", ### Insert a value here\n", + " n_filters=32,\n", + " n_layers=4,\n", + " n_gaussians=4,\n", + " dimensions=dimensions,\n", + ")\n", + "\n", + "s_decoder = SDecoder(\n", + " colour_channels=low_snr.shape[1],\n", + " s_code_channels=s_code_channels,\n", + " n_filters=s_code_channels,\n", + " dimensions=dimensions,\n", + ")\n", + "\n", + "use_direct_denoiser = True ### Insert a value here\n", + "if use_direct_denoiser:\n", + " direct_denoiser = UNet(\n", + " colour_channels=low_snr.shape[1],\n", + " n_filters=s_code_channels,\n", + " n_layers=n_layers,\n", + " downsampling=downsampling,\n", + " loss_fn=\"L2\",\n", + " dimensions=dimensions,\n", + " )\n", + "else:\n", + " direct_denoiser = None\n", + "\n", + "hub = Hub(\n", + " vae=lvae,\n", + " ar_decoder=ar_decoder,\n", + " s_decoder=s_decoder,\n", + " direct_denoiser=direct_denoiser,\n", + " data_mean=low_snr.mean(),\n", + " data_std=low_snr.std(),\n", + " n_grad_batches=n_grad_batches,\n", + " checkpointed=True,\n", + ")" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "### 1.5. Train the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 1.4.\n", + "\n", + "Open Tensorboard (check Task 3 in 01_CARE) to monitor training.\n", + "This model is unlike the previous two because it has more than one loss curve.\n", + "The cell below describes how to interpret each one.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Tensorboard metrics\n", + "\n", + "In the SCALARS tab, there will be 4 metrics to track (5 if direct denoiser is enabled). These are:
\n", + "1. `kl_loss` The Kullback-Leibler divergence between the VAE's approximate posterior and its prior. This can be thought of as a measure of how much information about the input image is going into the VAE's latent variables. We want information about the input's underlying clean signal to go into the latent variables, so this metric shouldn't go all the way to zero. Instead, it can typically go either up or down during training before plateauing.
\n", + "2. `reconstruction_loss` The negative log-likelihood of the AR decoder's predicted distribution given the input data. This is how accurately the AR decoder is able to predict the input. This value can go below zero and should decrease throughout training before plateauing.
\n", + "3. `elbo` The Evidence Lower Bound, which is the total loss of the main VAE. This is the sum of the kl and reconstruction loss and should decrease throughout training before plateauing.
\n", + "4. `sd_loss` The mean squared error between the noisy image and the image predicted by the signal decoder. This metric should steadily decrease towards zero without ever reaching it. Sometimes the loss will not go down for the first few epochs because its input (produced by the VAE) is rapidly changing. This is ok and the loss should start to decrease when the VAE stabilises.
\n", + "5. `dd_loss` The mean squared error between the output of the direct denoiser and the clean images predicted by the signal decoder. This will only be present if `use_direct_denoiser` is set to `True`. The metric should steadily decrease towards zero without ever reaching it, but may be unstable at the start of training as its targets (produced by the signal decoder) are rapidly changing.\n", + "\n", + "There will also be an IMAGES tab. This shows noisy input images from the validation set and some outputs. These will be two randomly sampled denoised images (sample 1 and sample 2), the average of ten denoised images (mmse) and if the direct denoiser is enabled, its output (direct estimate).\n", + "\n", + "If noise has not been fully removed from the output images, try increasing `n_gaussians` argument of the AR decoder. This will give it more flexibility to model complex noise characteristics. However, setting the value too high can lead to unstable training. Typically, values from 3 to 5 work best.\n", + "\n", + "Note that the trainer is set to train for only 10 minutes in this example. Remove the line with `max_time` to train fully." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 1.5.\n", + "\n", + "Now the model is ready to start training. Give the model a sensible name by setting `model_name` to a string, then run the following cells.\n", + "\n", + "The `max_time` parameter in the cell below means we'll only train the model for 10 minutes, just to get idea of what to expect. In the future, to remove the time restriction, the `max_time` parameter can be set to `None`.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`model_name` (str): Should be set to something appropriate so that the trained parameters can be used later for inference.
\n", + "`max_epochs` (int): The number of training epochs.
\n", + "`patience` (int): If the validation loss has plateaued for this many epochs, training will stop." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "model_name = ... ### Insert a value here\n", + "checkpoint_path = os.path.join(\"checkpoints\", model_name)\n", + "logger = TensorBoardLogger(checkpoint_path)\n", + "\n", + "max_epochs = 1000\n", + "max_time = \"00:00:10:00\"\n", + "patience = 100\n", + "\n", + "trainer = pl.Trainer(\n", + " logger=logger,\n", + " accelerator=\"gpu\",\n", + " devices=1,\n", + " max_epochs=max_epochs,\n", + " max_time=max_time, # Remove this time limit to train the model fully\n", + " log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n", + " callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "model_name = \"mito-confocal\" ### Insert a value here\n", + "checkpoint_path = os.path.join(\"checkpoints\", model_name)\n", + "logger = TensorBoardLogger(checkpoint_path)\n", + "\n", + "max_epochs = 1000\n", + "max_time = \"00:00:10:00\"\n", + "patience = 100\n", + "\n", + "trainer = pl.Trainer(\n", + " logger=logger,\n", + " accelerator=\"gpu\",\n", + " devices=1,\n", + " max_epochs=max_epochs,\n", + " max_time=max_time, # Remove this time limit to train the model fully\n", + " log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n", + " callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.fit(hub, train_loader, val_loader)\n", + "trainer.save_checkpoint(os.path.join(checkpoint_path, \"final_model.ckpt\"))\n", + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## Checkpoint 2\n", + "We've now trained a COSDD model to denoise our data. Continue to the next part to use it to get some results.\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Exercise 2. Inference with COSDD" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.1. Load test data\n", + "The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 10 to speed up inference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lowsnr_path = \"./../data/mito-confocal-lowsnr.tif\"\n", + "n_test_images = 5\n", + "# load the data\n", + "test_set = tifffile.imread(lowsnr_path)\n", + "test_set = test_set[:n_test_images, np.newaxis]\n", + "test_set = torch.from_numpy(test_set)\n", + "test_set = test_set.to(torch.float32)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As with training, data should be a `torch.Tensor` with dimensions: [Number of images, Channels, Z | Y | X] with data type float32." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Part 2. Create prediction dataloader" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`predict_batch_size` (int): Number of denoised images to produce at a time." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predict_batch_size = 1\n", + "\n", + "predict_set = utils.PredictDataset(test_set)\n", + "predict_loader = torch.utils.data.DataLoader(\n", + " predict_set,\n", + " batch_size=predict_batch_size,\n", + " shuffle=False,\n", + " pin_memory=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.3. Load trained model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 2.1.\n", + "\n", + "Our model was only trained for 10 minutes. This is long enough to get some denoising results, but a model trained for longer would do better. In the cell below, load the trained model by recalling the value you gave for `model_name`. Then procede through the notebook to look at how well it performs. \n", + "\n", + "Once you reach the end of the notebook, return to this cell to load a model that has been trained for 3.5 hours by uncommenting line 4, then run the notebook again to see how much difference the extra training time makes. \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = ... ### Insert a string here\n", + "checkpoint_path = os.path.join(\"checkpoints\", model_name)\n", + "\n", + "# checkpoint_path = \"checkpoints/mito-confocal-pretrained\" ### Once you reach the bottom of the notebook, return here and uncomment this line to see the pretrained model\n", + "\n", + "hub = Hub.load_from_checkpoint(os.path.join(checkpoint_path, \"final_model.ckpt\"))\n", + "\n", + "predictor = pl.Trainer(\n", + " accelerator=\"gpu\",\n", + " devices=1,\n", + " enable_progress_bar=False,\n", + " enable_checkpointing=False,\n", + " logger=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"mito-confocal\" ### Insert a string here\n", + "checkpoint_path = os.path.join(\"checkpoints\", model_name)\n", + "\n", + "# checkpoint_path = \"checkpoints/mito-confocal-pretrained\" ### Once you reach the bottom of the notebook, return here and uncomment this line to see the pretrained model\n", + "\n", + "hub = Hub.load_from_checkpoint(os.path.join(checkpoint_path, \"final_model.ckpt\"))\n", + "\n", + "predictor = pl.Trainer(\n", + " accelerator=\"gpu\",\n", + " devices=1,\n", + " enable_progress_bar=False,\n", + " enable_checkpointing=False,\n", + " logger=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.4. Denoise\n", + "In this section, we will look at how COSDD does inference.
\n", + "\n", + "The model denoises images randomly, giving us a different output each time. First, we will compare seven randomly sampled denoised images for the same noisy image. Then, we will produce a single consensus estimate by averaging 100 randomly sampled denoised images. Finally, if the direct denoiser was trained in the previous step, we will see how it can be used to estimate this average in a single pass." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.4.1 Random sampling \n", + "First, we will denoise each image seven times and look at the difference between each estimate. The output of the model is stored in the `samples` variable. This has dimensions [Number of images, Sample index, Channels, Z | Y | X] where different denoised samples for the same image are stored along sample index." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "use_direct_denoiser = False\n", + "n_samples = 7\n", + "\n", + "hub.direct_pred = use_direct_denoiser\n", + "samples = []\n", + "for _ in tqdm(range(n_samples)):\n", + " out = predictor.predict(hub, predict_loader)\n", + " out = torch.cat(out, dim=0)\n", + " samples.append(out)\n", + "\n", + "samples = torch.stack(samples, dim=1).half()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 2.2.\n", + "\n", + "Here, we'll look at the original noisy image and the seven denoised estimates. Change the value for `img_idx` to look at different images and change values for `top`, `bottom`, `left` and `right` to adjust the crop. Use this section to really explore the results. Compare high intensity reigons to low intensity reigons, zoom in and out and spot the differences between the different samples. \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vmin = np.percentile(test_set.numpy(), 1)\n", + "vmax = np.percentile(test_set.numpy(), 99)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "img_idx = 0\n", + "top = 0\n", + "bottom = 1024\n", + "left = 0\n", + "right = 1024\n", + "\n", + "crop = (0, slice(top, bottom), slice(left, right))\n", + "\n", + "fig, ax = plt.subplots(2, 4, figsize=(16, 8))\n", + "ax[0, 0].imshow(test_set[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[0, 0].set_title(\"Input\")\n", + "for i in range(n_samples):\n", + " ax[(i + 1) // 4, (i + 1) % 4].imshow(\n", + " samples[img_idx][i][crop], vmin=vmin, vmax=vmax\n", + " )\n", + " ax[(i + 1) // 4, (i + 1) % 4].set_title(f\"Sample {i+1}\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The six sampled denoised images have subtle differences that express the uncertainty involved in this denoising problem." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.4.2 MMSE estimate\n", + "\n", + "In the next cell, we sample many denoised images and average them for the minimum mean square estimate (MMSE). The averaged images will be stored in the `MMSEs` variable, which has the same dimensions as `low_snr`. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 2.3.\n", + "Set `n_samples` to 100 to average 100 images, or a different value to average a different number. Then visually inspeect the results. Examine how the MMSE result differs from the random sample.\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "use_direct_denoiser = False\n", + "n_samples = ... ### Insert an integer here\n", + "\n", + "hub.direct_pred = use_direct_denoiser\n", + "\n", + "samples = []\n", + "for _ in tqdm(range(n_samples)):\n", + " out = predictor.predict(hub, predict_loader)\n", + " out = torch.cat(out, dim=0)\n", + " samples.append(out)\n", + "\n", + "samples = torch.stack(samples, dim=1).half()\n", + "MMSEs = torch.mean(samples, dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "use_direct_denoiser = False\n", + "n_samples = 100 ### Insert an integer here\n", + "\n", + "hub.direct_pred = use_direct_denoiser\n", + "\n", + "samples = []\n", + "for _ in tqdm(range(n_samples)):\n", + " out = predictor.predict(hub, predict_loader)\n", + " out = torch.cat(out, dim=0)\n", + " samples.append(out)\n", + "\n", + "samples = torch.stack(samples, dim=1).half()\n", + "MMSEs = torch.mean(samples, dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "img_idx = 0\n", + "top = 0\n", + "bottom = 1024\n", + "left = 0\n", + "right = 1024\n", + "\n", + "crop = (0, slice(top, bottom), slice(left, right))\n", + "\n", + "fig, ax = plt.subplots(1, 3, figsize=(12, 4))\n", + "ax[0].imshow(test_set[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[0].set_title(\"Input\")\n", + "ax[1].imshow(samples[img_idx][0][crop], vmin=vmin, vmax=vmax)\n", + "ax[1].set_title(\"Sample\")\n", + "ax[2].imshow(MMSEs[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[2].set_title(\"MMSE\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The MMSE will usually be closer to the reference than an individual sample and would score a higher PSNR, although it will also be blurrier." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.4.3 Direct denoising\n", + "Sampling 100 images and averaging them is a very time consuming. If the direct denoiser was trained in a previous step, it can be used to directly output what the average denoised image would be for a given noisy image." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "### Task 2.4.\n", + "\n", + "Did you enable the direct denoiser in the previous notebook? If so, set `use_direct_denoiser` to `True` to use the Direct Denoiser for inference. If not, go back to section 2.3 to load the pretrained model and return here. \n", + "\n", + "Notice how much quicker the direct denoiser is than generating the MMSE results. Visually inspect and explore the results in the same way as before, notice how similar the direct estimate and MMSE estimate are.\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "use_direct_denoiser = ... ### Insert a boolean here\n", + "hub.direct_pred = use_direct_denoiser\n", + "\n", + "direct = predictor.predict(hub, predict_loader)\n", + "direct = torch.cat(direct, dim=0).half()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "use_direct_denoiser = True ### Insert a boolean here\n", + "hub.direct_pred = use_direct_denoiser\n", + "\n", + "direct = predictor.predict(hub, predict_loader)\n", + "direct = torch.cat(direct, dim=0).half()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "img_idx = 0\n", + "top = 0\n", + "bottom = 1024\n", + "left = 0\n", + "right = 1024\n", + "\n", + "crop = (0, slice(top, bottom), slice(left, right))\n", + "\n", + "fig, ax = plt.subplots(1, 3, figsize=(12, 4))\n", + "ax[0].imshow(test_set[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[0].set_title(\"Input\")\n", + "ax[1].imshow(direct[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[1].set_title(\"Direct\")\n", + "ax[2].imshow(MMSEs[img_idx][crop], vmin=vmin, vmax=vmax)\n", + "ax[2].set_title(\"MMSE\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.5. Incorrect receptive field\n", + "\n", + "We've now trained a model and used it to remove structured noise from our data. Before moving onto the next notebook, we'll look at what happens when a COSDD model is trained without considering the noise structures present. \n", + "\n", + "COSDD is able to separate imaging noise from clean signal because its autoregressive decoder has a receptive field that spans pixels containing correlated noise, i.e., the row or column of pixels. If its receptive field did not contain pixels with correlated noise, it would not be able to model them and they would be captured by the VAE's latent variables. To demonstrate this, the image below shows a Direct and MMSE estimate of a denoised image where the autoregressive decoder's receptive field was incorrectly set to vertical, leaving it unable to model horizontal noise." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## Checkpoint 3\n", + "\n", + "We've completed the process of training and applying a COSDD model for denoising, but there's still more it can do. Optionally continue to the bonus notebook, bonus-exercise-generation.ipynb, to see how the model of the data can be used to generate new clean and noisy images.\n", + "\n", + "Otherwise, continue to 04_DenoiSplit.\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "05_image_restoration", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/04_DenoiSplit/exercise.ipynb b/04_DenoiSplit/exercise.ipynb new file mode 100644 index 0000000..65b7aab --- /dev/null +++ b/04_DenoiSplit/exercise.ipynb @@ -0,0 +1,832 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "290f929e", + "metadata": {}, + "source": [ + "# denoiSplit: joint splitting and unsupervised denoising\n", + "In this notebook, we tackle the problem of joint splitting and unsupervised denoising, which has a use case in the field of fluorescence microscopy. From a technical perspective, given a noisy image $x$, the goal is to predict two images $c_1$ and $c_2$ such that $x = c_1 + c_2 + n$, where $n$ is the noise in $x$. In other words, we have a superimposed image $x$ and we want to predict the denoised estimates of the constituent images $c_1$ and $c_2$. It is important to note that the network is trained with noisy data and the denoising is done in a unsupervised manner. \n", + "\n", + "For this, we will use [denoiSplit](https://arxiv.org/pdf/2403.11854.pdf), a recently developed approach for this task. In this notebook we train denoiSplit and later evaluate it on one validation frame. The overall schema for denoiSplit is shown below:\n", + "\n", + "\n", + "\"drawing\"\n" + ] + }, + { + "cell_type": "markdown", + "id": "9cd72d68", + "metadata": {}, + "source": [ + "Here, we look at CCPs (clathrin-coated pits) vs ER (Endoplasmic reticulum) task, one of the tasks tackled by denoiSplit which is generated from [BioSR](https://figshare.com/articles/dataset/BioSR/13264793) dataset.\n", + "\n", + "1) First, we will load both CCPs and ER images.
\n", + "2) We'll add synthetic Poisson and Gaussian noise to them. This simulates the noise that typically occurs in light microscopy.
\n", + "3) Each noisy CCPs image will be added to each corresponding ER image, making a superimposed image, $x$.
\n", + "4) A VSE network will be trained to take $x$ as input and return unsplit, denoised CCPs and ER images.\n", + "5) You'll inspect the results, then re-run the notebook with different noise levels and model hyper-parameters to see how performance changes." + ] + }, + { + "cell_type": "markdown", + "id": "2bedf584", + "metadata": {}, + "source": [ + "
\n", + "Set your python kernel to 05_image_restoration\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "76107363", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "## Set directories \n", + "In the next cell, we enumerate the necessary fields for this task." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47dbd8fb", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "work_dir = \".\"\n", + "tensorboard_log_dir = os.path.join(work_dir, \"tensorboard_logs\")\n", + "os.makedirs(tensorboard_log_dir, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84e96ca6", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('./denoisplit')\n", + "\n", + "from pytorch_lightning.loggers import TensorBoardLogger\n", + "from pytorch_lightning.callbacks import ModelCheckpoint\n", + "from torch.utils.data import DataLoader\n", + "import pytorch_lightning as pl\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.patches as patches\n", + "\n", + "from denoisplit.data_loader.vanilla_dloader import MultiChDloader\n", + "from denoisplit.analysis.plot_utils import clean_ax\n", + "from denoisplit.configs.biosr_config import get_config\n", + "from denoisplit.training import create_dataset\n", + "from denoisplit.nets.model_utils import create_model\n", + "from denoisplit.core.metric_monitor import MetricMonitor\n", + "from denoisplit.scripts.run import get_mean_std_dict_for_model\n", + "from denoisplit.core.data_split_type import DataSplitType\n", + "from denoisplit.scripts.evaluate import get_highsnr_data\n", + "from denoisplit.analysis.mmse_prediction import get_dset_predictions\n", + "from denoisplit.data_loader.patch_index_manager import GridAlignement\n", + "from denoisplit.scripts.evaluate import avg_range_inv_psnr, compute_multiscale_ssim" + ] + }, + { + "cell_type": "markdown", + "id": "ff6b18b2", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "\n", + "

\n", + " Several Things to try:

\n", + "
    \n", + "
  1. Run once with unchanged config to see the performance.
  2. \n", + "
  3. Increase the noise (double the Gaussian noise?) and see how performance degrades. The snippet just after this list shows example config values to change.
  4. \n", + "
      \n", + "
    1. Recap: Poisson and Gaussian are the two most prominent pixelwise independent noise sources. Here, we incorporate both. Note that the larger the noise, the harder the task becomes.
    2. \n", + "
    \n", + "
  5. Increase max_epochs if you want to get better performance.
  6. \n", + "
  7. For faster training (but compromising on performance), reduce the number of hierarchy levels and/or the channel count by modifying config.model.z_dims.
  8. \n", + "
  9. First we train the model to split CCPs and ER channels. Later you can try to split other channels, e.g. F-actin and ER. You'll be able to see that this is a substantially harder task.
  10. \n", + "
\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "id": "788a6142", + "metadata": {}, + "source": [ + "## Config " + ] + }, + { + "cell_type": "markdown", + "id": "d1a5e742", + "metadata": {}, + "source": [ + "Here we'll load the data and set model hyper-parameters.\n", + "To create the dataset, we'll load two sets of images: CCPs (clathrin-coated pits) and ER (Endoplasmic reticulum). \n", + "Each image from the CCPs will be added to an image from ER, then noise added on top.\n", + "\n", + "The level of noise is determined by `config.data.poisson_noise_factor` and `config.data.synthetic_gaussian_scale`.\n", + "The former simulates photon shot noise, which is more destructive on lower intensity signals.\n", + "The latter simulates electronic read noise, which has a constant variance for all signal intensities.\n" + ] + }, + { + "cell_type": "markdown", + "id": "04735f11", + "metadata": {}, + "source": [ + "`config.data.poisson_noise_factor` (float): the intensity of the Poisson (shot) noise.\n", + "\n", + "`config.data.synthetic_gaussian_scale` (float): the intensity of the Gaussian (readout) noise.\n", + "\n", + "`config.model.z_dims` (list(int)): Determines the depth of our network. The number of entries is the number of levels. The value of each entry is the number of hidden dimensions at each level.\n", + "\n", + "`config.training.lr` (float): The learning rate.\n", + "\n", + "`config.training.max_epochs` (int): Number of training epochs. Increase for better performance, decrease for shorter training time.\n", + "\n", + "`config.training.batch_size` (int): Training batch size. Increasing this will require more memory. Performance may improve, but bigger batches aren't always better.\n", + "\n", + "`config.training.num_workers` (int): Number of subprocesses to use for data loading. This is different for different GPUs.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bcca8dc2", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "datapath = \"./../data/\"\n", + "\n", + "# load the default config.\n", + "config = get_config()\n", + "\n", + "config.data.ch1_fname = 'ER/GT_all.mrc'\n", + "config.data.ch2_fname = 'CCPs/GT_all.mrc'\n", + "# Channge the noise level\n", + "config.data.poisson_noise_factor = (\n", + " 1000 # 1000 is the default value. noise increases with the value.\n", + ")\n", + "config.data.synthetic_gaussian_scale = (\n", + " 5000 # 5000 is the default value. noise increases with the value.\n", + ")\n", + "\n", + "# change the number of hierarchy levels.\n", + "config.model.z_dims = [128, 128, 128, 128]\n", + "\n", + "# change the training parameters\n", + "config.training.lr = 3e-3\n", + "config.training.max_epochs = 10\n", + "config.training.batch_size = 8\n", + "config.training.num_workers = 4\n", + "\n", + "config.workdir = \".\"" + ] + }, + { + "cell_type": "markdown", + "id": "e83242ab", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "## Create the dataset and pytorch dataloaders. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "962f9c25", + "metadata": {}, + "outputs": [], + "source": [ + "print(config)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ff3cd66", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "train_dset = MultiChDloader(config.data,\n", + " datapath,\n", + " datasplit_type=DataSplitType.Train,\n", + " val_fraction=config.training.val_fraction,\n", + " test_fraction=config.training.test_fraction,\n", + " normalized_input=config.data.normalized_input,\n", + " use_one_mu_std=config.data.use_one_mu_std,\n", + " enable_rotation_aug=config.data.train_aug_rotate\n", + " )\n", + "val_dset = MultiChDloader(config.data,\n", + " datapath,\n", + " datasplit_type=DataSplitType.Val,\n", + " val_fraction=config.training.val_fraction,\n", + " test_fraction=config.training.test_fraction,\n", + " normalized_input=config.data.normalized_input,\n", + " use_one_mu_std=config.data.use_one_mu_std,\n", + " enable_rotation_aug=False, # No rotation aug on validation\n", + " max_val=train_dset.get_max_val(),\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b47e62c", + "metadata": {}, + "outputs": [], + "source": [ + "mean_dict, std_dict = train_dset.compute_mean_std()\n", + "train_dset.set_mean_std(mean_dict, std_dict)\n", + "val_dset.set_mean_std(mean_dict, std_dict)\n", + "\n", + "mean_dict, std_dict = get_mean_std_dict_for_model(config, train_dset)" + ] + }, + { + "cell_type": "markdown", + "id": "28072b04", + "metadata": {}, + "source": [ + "## Inspecting the training data generated using the above config.\n", + "
\n", + "If you want to change the noise, then you should change the config first and run the following cell again to see how the training data changes in terms of noise.\n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a3d6ef80", + "metadata": {}, + "outputs": [], + "source": [ + "val_dset.set_img_sz(800, 64)\n", + "inp, tar = val_dset[0]\n", + "_,ax = plt.subplots(1,3, figsize=(15,5))\n", + "ax[0].imshow(inp[0], cmap='magma')\n", + "ax[0].set_title('Input')\n", + "ax[1].imshow(tar[0], cmap='magma')\n", + "ax[1].set_title('Channel 1')\n", + "ax[2].imshow(tar[1], cmap='magma')\n", + "ax[2].set_title('Channel 2')\n", + "\n", + "val_dset.set_img_sz(config.data.image_size, config.data.image_size)" + ] + }, + { + "cell_type": "markdown", + "id": "c6b959db", + "metadata": {}, + "source": [ + "## Define the dataloaders" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09035708", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "batch_size = config.training.batch_size\n", + "train_dloader = DataLoader(\n", + " train_dset,\n", + " pin_memory=False,\n", + " num_workers=config.training.num_workers,\n", + " shuffle=True,\n", + " batch_size=batch_size,\n", + ")\n", + "val_dloader = DataLoader(\n", + " val_dset,\n", + " pin_memory=False,\n", + " num_workers=config.training.num_workers,\n", + " shuffle=False,\n", + " batch_size=batch_size,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a0dc243f", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "## Create the model.\n", + "Here, we instantiate the [denoiSplit model](https://arxiv.org/pdf/2403.11854.pdf). For simplicity, we have disabled the noise model. For enabling the noise model, one would additionally have to train a denoiser. The next step would be to create a noise model using the noisy data and the corresponding denoised predictions. \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cec5ec5", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "model = create_model(config, mean_dict, std_dict)\n", + "model = model.cuda()" + ] + }, + { + "cell_type": "markdown", + "id": "f9cde1e7", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "## Start training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "817e538b", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "logger = TensorBoardLogger(tensorboard_log_dir, name=\"\", version=\"\", default_hp_metric=False)\n", + "trainer = pl.Trainer(\n", + " max_epochs=config.training.max_epochs,\n", + " gradient_clip_val=(\n", + " None\n", + " if not model.automatic_optimization\n", + " else config.training.grad_clip_norm_value\n", + " ),\n", + " logger=logger,\n", + " precision=config.training.precision,\n", + ")\n", + "trainer.fit(model, train_dloader, val_dloader)" + ] + }, + { + "cell_type": "markdown", + "id": "3c06421e", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "## Evaluate the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eca722a9", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "model.eval()\n", + "_ = model.cuda()\n", + "eval_frame_idx = 0\n", + "# reducing the data, just for speed\n", + "val_dset.reduce_data(t_list=[eval_frame_idx])\n", + "mmse_count = 10\n", + "overlapping_padding_kwargs = {\n", + " \"mode\": config.data.get(\"padding_mode\", \"constant\"),\n", + "}\n", + "if overlapping_padding_kwargs[\"mode\"] == \"constant\":\n", + " 
overlapping_padding_kwargs[\"constant_values\"] = config.data.get(\"padding_value\", 0)\n", + "val_dset.set_img_sz(\n", + " 128,\n", + " 32,\n", + " grid_alignment=GridAlignement.Center,\n", + " overlapping_padding_kwargs=overlapping_padding_kwargs,\n", + ")\n", + "\n", + "# MMSE prediction\n", + "pred_tiled, rec_loss, logvar_tiled, patch_psnr_tuple, pred_std_tiled = (\n", + " get_dset_predictions(\n", + " model,\n", + " val_dset,\n", + " batch_size,\n", + " num_workers=config.training.num_workers,\n", + " mmse_count=mmse_count,\n", + " model_type=config.model.model_type,\n", + " )\n", + ")\n", + "\n", + "# One sample prediction\n", + "pred1_tiled, *_ = get_dset_predictions(\n", + " model,\n", + " val_dset,\n", + " batch_size,\n", + " num_workers=config.training.num_workers,\n", + " mmse_count=1,\n", + " model_type=config.model.model_type,\n", + ")\n", + "# One sample prediction\n", + "pred2_tiled, *_ = get_dset_predictions(\n", + " model,\n", + " val_dset,\n", + " batch_size,\n", + " num_workers=config.training.num_workers,\n", + " mmse_count=1,\n", + " model_type=config.model.model_type,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c1c9bd5b", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "## Stitching predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38df4c25", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "from denoisplit.analysis.stitch_prediction import stitch_predictions\n", + "\n", + "pred = stitch_predictions(pred_tiled, val_dset)\n", + "\n", + "\n", + "# ignore pixels at the [right/bottom] boundary.\n", + "def print_ignored_pixels():\n", + " ignored_pixels = 1\n", + " while (\n", + " pred[\n", + " 0,\n", + " -ignored_pixels:,\n", + " -ignored_pixels:,\n", + " ].std()\n", + " == 0\n", + " ):\n", + " ignored_pixels += 1\n", + " ignored_pixels -= 1\n", + " return ignored_pixels\n", + "\n", + "\n", + "actual_ignored_pixels = print_ignored_pixels()\n", + "pred = pred[:, :-actual_ignored_pixels, :-actual_ignored_pixels]\n", + "pred1 = stitch_predictions(pred1_tiled, val_dset)[\n", + " :, :-actual_ignored_pixels, :-actual_ignored_pixels\n", + "]\n", + "pred2 = stitch_predictions(pred2_tiled, val_dset)[\n", + " :, :-actual_ignored_pixels, :-actual_ignored_pixels\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "f451a4e1", + "metadata": {}, + "source": [ + "## Get the ground truth" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d0866ed", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "highres_data = get_highsnr_data(config, datapath, DataSplitType.Val)\n", + "\n", + "highres_data = highres_data[\n", + " eval_frame_idx : eval_frame_idx + 1,\n", + " :-actual_ignored_pixels,\n", + " :-actual_ignored_pixels,\n", + "]\n", + "\n", + "noisy_data = val_dset._noise_data[..., 1:] + val_dset._data\n", + "noisy_data = noisy_data[..., :-actual_ignored_pixels, :-actual_ignored_pixels, :]\n", + "model_input = np.mean(noisy_data, axis=-1)" + ] + }, + { + "cell_type": "markdown", + "id": "dc0a5837", + "metadata": {}, + "source": [ + "\n", + "

Checkpoint 1: Model trained

\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "bb4a0b75", + "metadata": { + "cell_marker": "\"\"\"" + }, + "source": [ + "# Qualitative performance on a random crop\n", + "denoiSplit is capable of sampling from a learned posterior.\n", + "Here we show full input frame and a randomly cropped input (300*300),\n", + "two corresponding prediction samples, the difference between the two samples (S1−S2),\n", + "the MMSE prediction, and otherwise unused high SNR microscopy crop. \n", + "The MMSE predictions are computed by averaging 10 samples. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29754975", + "metadata": {}, + "outputs": [], + "source": [ + "def add_str(ax_, txt):\n", + " \"\"\"\n", + " Add psnr string to the axes\n", + " \"\"\"\n", + " textstr = txt\n", + " props = dict(boxstyle=\"round\", facecolor=\"gray\", alpha=0.5)\n", + " # place a text box in upper left in axes coords\n", + " ax_.text(\n", + " 0.05,\n", + " 0.95,\n", + " textstr,\n", + " transform=ax_.transAxes,\n", + " fontsize=11,\n", + " verticalalignment=\"top\",\n", + " bbox=props,\n", + " color=\"white\",\n", + " )\n", + "\n", + "\n", + "ncols = 7\n", + "nrows = 2\n", + "sz = 300\n", + "hs = np.random.randint(0, highres_data.shape[1] - sz)\n", + "ws = np.random.randint(0, highres_data.shape[2] - sz)\n", + "_, ax = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4))\n", + "ax[0, 0].imshow(model_input[0], cmap=\"magma\")\n", + "\n", + "rect = patches.Rectangle((ws, hs), sz, sz, linewidth=1, edgecolor=\"r\", facecolor=\"none\")\n", + "ax[0, 0].add_patch(rect)\n", + "ax[1, 0].imshow(model_input[0, hs : hs + sz, ws : ws + sz], cmap=\"magma\")\n", + "add_str(ax[0, 0], \"Full Input Frame\")\n", + "add_str(ax[1, 0], \"Random Input Crop\")\n", + "\n", + "ax[0, 1].imshow(noisy_data[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + "ax[1, 1].imshow(noisy_data[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "\n", + "ax[0, 2].imshow(pred1[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + "ax[1, 2].imshow(pred1[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "\n", + "ax[0, 3].imshow(pred2[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + "ax[1, 3].imshow(pred2[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "\n", + "diff = pred2 - pred1\n", + "ax[0, 4].imshow(diff[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"coolwarm\")\n", + "ax[1, 4].imshow(diff[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"coolwarm\")\n", + "\n", + "ax[0, 5].imshow(pred[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + "ax[1, 5].imshow(pred[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "\n", + "\n", + "ax[0, 6].imshow(highres_data[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + "ax[1, 6].imshow(highres_data[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "plt.subplots_adjust(wspace=0.02, hspace=0.02)\n", + "ax[0, 0].set_title(\"Model Input\", size=13)\n", + "ax[0, 1].set_title(\"Target\", size=13)\n", + "ax[0, 2].set_title(\"Sample 1 (S1)\", size=13)\n", + "ax[0, 3].set_title(\"Sample 2 (S2)\", size=13)\n", + "ax[0, 4].set_title('\"S2\" - \"S1\"', size=13)\n", + "ax[0, 5].set_title(f\"Prediction MMSE({mmse_count})\", size=13)\n", + "ax[0, 6].set_title(\"High SNR Reality\", size=13)\n", + "\n", + "twinx = ax[0, 6].twinx()\n", + "twinx.set_ylabel(\"Channel 1\", size=13)\n", + "clean_ax(twinx)\n", + "twinx = ax[1, 6].twinx()\n", + "twinx.set_ylabel(\"Channel 2\", size=13)\n", + "clean_ax(twinx)\n", + "clean_ax(ax)" + ] + }, + { + "cell_type": "markdown", + 
"id": "d1ce25bb", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "# Qualitative performance on multiple random crops\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3dc1e50d", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "nimgs = 3\n", + "ncols = 7\n", + "nrows = 2 * nimgs\n", + "sz = 300\n", + "_, ax = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4))\n", + "\n", + "for img_idx in range(nimgs):\n", + " hs = np.random.randint(0, highres_data.shape[1] - sz)\n", + " ws = np.random.randint(0, highres_data.shape[2] - sz)\n", + " ax[2 * img_idx, 0].imshow(model_input[0], cmap=\"magma\")\n", + "\n", + " rect = patches.Rectangle(\n", + " (ws, hs), sz, sz, linewidth=1, edgecolor=\"r\", facecolor=\"none\"\n", + " )\n", + " ax[2 * img_idx, 0].add_patch(rect)\n", + " ax[2 * img_idx + 1, 0].imshow(\n", + " model_input[0, hs : hs + sz, ws : ws + sz], cmap=\"magma\"\n", + " )\n", + " add_str(ax[2 * img_idx, 0], \"Full Input Frame\")\n", + " add_str(ax[2 * img_idx + 1, 0], \"Random Input Crop\")\n", + "\n", + " ax[2 * img_idx, 1].imshow(\n", + " noisy_data[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\"\n", + " )\n", + " ax[2 * img_idx + 1, 1].imshow(\n", + " noisy_data[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\"\n", + " )\n", + "\n", + " ax[2 * img_idx, 2].imshow(pred1[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + " ax[2 * img_idx + 1, 2].imshow(pred1[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "\n", + " ax[2 * img_idx, 3].imshow(pred2[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + " ax[2 * img_idx + 1, 3].imshow(pred2[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "\n", + " diff = pred2 - pred1\n", + " ax[2 * img_idx, 4].imshow(diff[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"coolwarm\")\n", + " ax[2 * img_idx + 1, 4].imshow(\n", + " diff[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"coolwarm\"\n", + " )\n", + "\n", + " ax[2 * img_idx, 5].imshow(pred[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + " ax[2 * img_idx + 1, 5].imshow(pred[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "\n", + " ax[2 * img_idx, 6].imshow(\n", + " highres_data[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\"\n", + " )\n", + " ax[2 * img_idx + 1, 6].imshow(\n", + " highres_data[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\"\n", + " )\n", + "\n", + " twinx = ax[2 * img_idx, 6].twinx()\n", + " twinx.set_ylabel(\"Channel 1\", size=15)\n", + " clean_ax(twinx)\n", + "\n", + " twinx = ax[2 * img_idx + 1, 6].twinx()\n", + " twinx.set_ylabel(\"Channel 2\", size=15)\n", + " clean_ax(twinx)\n", + "\n", + "ax[0, 0].set_title(\"Model Input\", size=15)\n", + "ax[0, 1].set_title(\"Target\", size=15)\n", + "ax[0, 2].set_title(\"Sample 1 (S1)\", size=15)\n", + "ax[0, 3].set_title(\"Sample 2 (S2)\", size=15)\n", + "ax[0, 4].set_title('\"S2\" - \"S1\"', size=15)\n", + "ax[0, 5].set_title(f\"Prediction MMSE({mmse_count})\", size=15)\n", + "ax[0, 6].set_title(\"High SNR Reality\", size=15)\n", + "\n", + "clean_ax(ax)\n", + "plt.subplots_adjust(wspace=0.02, hspace=0.02)\n", + "# plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "3db2fe0b", + "metadata": {}, + "source": [ + "
\n", + "

Questions:

\n", + " 1) When is it relatively easy to split the two structures from the input?
\n", + " 2) Why might you see the grid-like artifacts and what can be done to mitigate this?
\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "id": "d94ea88e", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "## Quantitative performance\n", + "We evaluate on two metrics, Multiscale SSIM and PSNR.\n", + "\n", + "Multi-scale SSIM is a metric that computes SSIM at multiple scales and averages them. It's reminiscent of multiscale processing in the early vision system \n", + "\n", + "PSNR is a metric that computes the peak signal-to-noise ratio. It's one of the most widely used metrics to measure the quality of image reconstruction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc3246db", + "metadata": {}, + "outputs": [], + "source": [ + "mean_tar = mean_dict[\"target\"].cpu().numpy().squeeze().reshape(1, 1, 1, 2)\n", + "std_tar = std_dict[\"target\"].cpu().numpy().squeeze().reshape(1, 1, 1, 2)\n", + "pred_unnorm = pred * std_tar + mean_tar\n", + "\n", + "psnr_list = [\n", + " avg_range_inv_psnr(highres_data[..., i].copy(), pred_unnorm[..., i].copy())\n", + " for i in range(highres_data.shape[-1])\n", + "]\n", + "ssim_list = compute_multiscale_ssim(highres_data.copy(), pred_unnorm.copy())\n", + "print(\"Metric: Ch1\\t Ch2\")\n", + "print(f\"PSNR : {psnr_list[0]:.2f}\\t {psnr_list[1]:.2f}\")\n", + "print(f\"MS-SSIM : {ssim_list[0]:.3f}\\t {ssim_list[1]:.3f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "ea1bdbbe", + "metadata": {}, + "source": [ + "

Checkpoint 2: Try one of the \"Several things to try\"

\n", + "\n", + "
\n", + "\n", + "Click [here](#things-to-try) to go back to the relevant section." + ] + }, + { + "cell_type": "markdown", + "id": "37ef9a4b", + "metadata": {}, + "source": [ + "

End of the exercise

\n", + "
" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "all", + "main_language": "python" + }, + "kernelspec": { + "display_name": "05_image_restoration", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/04_DenoiSplit/imgs/teaser.png b/04_DenoiSplit/imgs/teaser.png new file mode 100755 index 0000000..07ae47b Binary files /dev/null and b/04_DenoiSplit/imgs/teaser.png differ diff --git a/04_DenoiSplit/solution.ipynb b/04_DenoiSplit/solution.ipynb new file mode 100755 index 0000000..3ba040d --- /dev/null +++ b/04_DenoiSplit/solution.ipynb @@ -0,0 +1,847 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "290f929e", + "metadata": {}, + "source": [ + "# denoiSplit: joint splitting and unsupervised denoising\n", + "In this notebook, we tackle the problem of joint splitting and unsupervised denoising, which has a use case in the field of fluorescence microscopy. From a technical perspective, given a noisy image $x$, the goal is to predict two images $c_1$ and $c_2$ such that $x = c_1 + c_2 + n$, where $n$ is the noise in $x$. In other words, we have a superimposed image $x$ and we want to predict the denoised estimates of the constituent images $c_1$ and $c_2$. It is important to note that the network is trained with noisy data and the denoising is done in a unsupervised manner. \n", + "\n", + "For this, we will use [denoiSplit](https://arxiv.org/pdf/2403.11854.pdf), a recently developed approach for this task. In this notebook we train denoiSplit and later evaluate it on one validation frame. The overall schema for denoiSplit is shown below:\n", + "\n", + "\n", + "\"drawing\"\n" + ] + }, + { + "cell_type": "markdown", + "id": "9cd72d68", + "metadata": {}, + "source": [ + "Here, we look at CCPs (clathrin-coated pits) vs ER (Endoplasmic reticulum) task, one of the tasks tackled by denoiSplit which is generated from [BioSR](https://figshare.com/articles/dataset/BioSR/13264793) dataset.\n", + "\n", + "1) First, we will load both CCPs and ER images.
\n", + "2) We'll add synthetic Poisson and Gaussian noise to them. This simulates the noise that typically occurs in light microscopy.
\n", + "3) Each noisy CCPs image will be added to each corresponding ER image, making a superimposed image, $x$.
\n", + "4) A VSE network will be trained to take $x$ as input and return unsplit, denoised CCPs and ER images.\n", + "5) You'll inspect the results, then re-run the notebook with different noise levels and model hyper-parameters to see how performance changes." + ] + }, + { + "cell_type": "markdown", + "id": "2bedf584", + "metadata": {}, + "source": [ + "
\n", + "Set your python kernel to 05_image_restoration\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "76107363", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "## Set directories \n", + "In the next cell, we enumerate the necessary fields for this task." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47dbd8fb", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "work_dir = \".\"\n", + "tensorboard_log_dir = os.path.join(work_dir, \"tensorboard_logs\")\n", + "os.makedirs(tensorboard_log_dir, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84e96ca6", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('./denoisplit')\n", + "\n", + "from pytorch_lightning.loggers import TensorBoardLogger\n", + "from pytorch_lightning.callbacks import ModelCheckpoint\n", + "from torch.utils.data import DataLoader\n", + "import pytorch_lightning as pl\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.patches as patches\n", + "\n", + "from denoisplit.data_loader.vanilla_dloader import MultiChDloader\n", + "from denoisplit.analysis.plot_utils import clean_ax\n", + "from denoisplit.configs.biosr_config import get_config\n", + "from denoisplit.training import create_dataset\n", + "from denoisplit.nets.model_utils import create_model\n", + "from denoisplit.core.metric_monitor import MetricMonitor\n", + "from denoisplit.scripts.run import get_mean_std_dict_for_model\n", + "from denoisplit.core.data_split_type import DataSplitType\n", + "from denoisplit.scripts.evaluate import get_highsnr_data\n", + "from denoisplit.analysis.mmse_prediction import get_dset_predictions\n", + "from denoisplit.data_loader.patch_index_manager import GridAlignement\n", + "from denoisplit.scripts.evaluate import avg_range_inv_psnr, compute_multiscale_ssim" + ] + }, + { + "cell_type": "markdown", + "id": "ff6b18b2", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "\n", + "

\n", + " Several Things to try:

\n", + "
    \n", + "
  1. Run once with unchanged config to see the performance.
  2. \n", + "
  3. Increase the noise (double the Gaussian noise?) and see how performance degrades.
  4. \n", + "
      \n", + "
    1. Recap: Poisson and Gaussian are the two most prominent pixelwise independent noise sources. Here, we incorporate both. Note that the larger the noise, the harder the task becomes.
    2. \n", + "
    \n", + "
  5. Increase max_epochs if you want to get better performance.
  6. \n", + "
  7. For faster training (but compromising on performance), reduce the number of hierarchy levels and/or the channel count by modifying config.model.z_dims.
  8. \n", + "
  9. First we train the model to split CCPs and ER channels. Later you can try to split other channels, e.g. F-actin and ER. You'll be able to see that this is a substantially harder task.
  10. \n", + "
\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "id": "788a6142", + "metadata": {}, + "source": [ + "## Config " + ] + }, + { + "cell_type": "markdown", + "id": "d1a5e742", + "metadata": {}, + "source": [ + "Here we'll load the data and set model hyper-parameters.\n", + "To create the dataset, we'll load two sets of images: CCPs (clathrin-coated pits) and ER (Endoplasmic reticulum). \n", + "Each image from the CCPs will be added to an image from ER, then noise added on top.\n", + "\n", + "The level of noise is determined by `config.data.poisson_noise_factor` and `config.data.synthetic_gaussian_scale`.\n", + "The former simulates photon shot noise, which is more destructive on lower intensity signals.\n", + "The latter simulates electronic read noise, which has a constant variance for all signal intensities.\n" + ] + }, + { + "cell_type": "markdown", + "id": "04735f11", + "metadata": {}, + "source": [ + "`config.data.poisson_noise_factor` (float): the intensity of the Poisson (shot) noise.\n", + "\n", + "`config.data.synthetic_gaussian_scale` (float): the intensity of the Gaussian (readout) noise.\n", + "\n", + "`config.model.z_dims` (list(int)): Determines the depth of our network. The number of entries is the number of levels. The value of each entry is the number of hidden dimensions at each level.\n", + "\n", + "`config.training.lr` (float): The learning rate.\n", + "\n", + "`config.training.max_epochs` (int): Number of training epochs. Increase for better performance, decrease for shorter training time.\n", + "\n", + "`config.training.batch_size` (int): Training batch size. Increasing this will require more memory. Performance may improve, but bigger batches aren't always better.\n", + "\n", + "`config.training.num_workers` (int): Number of subprocesses to use for data loading. This is different for different GPUs.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bcca8dc2", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "datapath = \"./../data/\"\n", + "\n", + "# load the default config.\n", + "config = get_config()\n", + "\n", + "config.data.ch1_fname = 'ER/GT_all.mrc'\n", + "config.data.ch2_fname = 'CCPs/GT_all.mrc'\n", + "# Channge the noise level\n", + "config.data.poisson_noise_factor = (\n", + " 1000 # 1000 is the default value. noise increases with the value.\n", + ")\n", + "config.data.synthetic_gaussian_scale = (\n", + " 5000 # 5000 is the default value. noise increases with the value.\n", + ")\n", + "\n", + "# change the number of hierarchy levels.\n", + "config.model.z_dims = [128, 128, 128, 128]\n", + "\n", + "# change the training parameters\n", + "config.training.lr = 3e-3\n", + "config.training.max_epochs = 10\n", + "config.training.batch_size = 8\n", + "config.training.num_workers = 4\n", + "\n", + "config.workdir = \".\"" + ] + }, + { + "cell_type": "markdown", + "id": "e83242ab", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "## Create the dataset and pytorch dataloaders. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "962f9c25", + "metadata": {}, + "outputs": [], + "source": [ + "print(config)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ff3cd66", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "train_dset = MultiChDloader(config.data,\n", + " datapath,\n", + " datasplit_type=DataSplitType.Train,\n", + " val_fraction=config.training.val_fraction,\n", + " test_fraction=config.training.test_fraction,\n", + " normalized_input=config.data.normalized_input,\n", + " use_one_mu_std=config.data.use_one_mu_std,\n", + " enable_rotation_aug=config.data.train_aug_rotate\n", + " )\n", + "val_dset = MultiChDloader(config.data,\n", + " datapath,\n", + " datasplit_type=DataSplitType.Val,\n", + " val_fraction=config.training.val_fraction,\n", + " test_fraction=config.training.test_fraction,\n", + " normalized_input=config.data.normalized_input,\n", + " use_one_mu_std=config.data.use_one_mu_std,\n", + " enable_rotation_aug=False, # No rotation aug on validation\n", + " max_val=train_dset.get_max_val(),\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b47e62c", + "metadata": {}, + "outputs": [], + "source": [ + "mean_dict, std_dict = train_dset.compute_mean_std()\n", + "train_dset.set_mean_std(mean_dict, std_dict)\n", + "val_dset.set_mean_std(mean_dict, std_dict)\n", + "\n", + "mean_dict, std_dict = get_mean_std_dict_for_model(config, train_dset)" + ] + }, + { + "cell_type": "markdown", + "id": "28072b04", + "metadata": {}, + "source": [ + "## Inspecting the training data generated using the above config.\n", + "
\n", + "If you want to change the noise, then you should change the config first and run the following cell again to see how the training data changes in terms of noise.\n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a3d6ef80", + "metadata": {}, + "outputs": [], + "source": [ + "val_dset.set_img_sz(800, 64)\n", + "inp, tar = val_dset[0]\n", + "_,ax = plt.subplots(1,3, figsize=(15,5))\n", + "ax[0].imshow(inp[0], cmap='magma')\n", + "ax[0].set_title('Input')\n", + "ax[1].imshow(tar[0], cmap='magma')\n", + "ax[1].set_title('Channel 1')\n", + "ax[2].imshow(tar[1], cmap='magma')\n", + "ax[2].set_title('Channel 2')\n", + "\n", + "val_dset.set_img_sz(config.data.image_size, config.data.image_size)" + ] + }, + { + "cell_type": "markdown", + "id": "c6b959db", + "metadata": {}, + "source": [ + "## Define the dataloaders" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09035708", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "batch_size = config.training.batch_size\n", + "train_dloader = DataLoader(\n", + " train_dset,\n", + " pin_memory=False,\n", + " num_workers=config.training.num_workers,\n", + " shuffle=True,\n", + " batch_size=batch_size,\n", + ")\n", + "val_dloader = DataLoader(\n", + " val_dset,\n", + " pin_memory=False,\n", + " num_workers=config.training.num_workers,\n", + " shuffle=False,\n", + " batch_size=batch_size,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a0dc243f", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "## Create the model.\n", + "Here, we instantiate the [denoiSplit model](https://arxiv.org/pdf/2403.11854.pdf). For simplicity, we have disabled the noise model. For enabling the noise model, one would additionally have to train a denoiser. The next step would be to create a noise model using the noisy data and the corresponding denoised predictions. \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cec5ec5", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "model = create_model(config, mean_dict, std_dict)\n", + "model = model.cuda()" + ] + }, + { + "cell_type": "markdown", + "id": "f9cde1e7", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "## Start training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "817e538b", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "logger = TensorBoardLogger(tensorboard_log_dir, name=\"\", version=\"\", default_hp_metric=False)\n", + "trainer = pl.Trainer(\n", + " max_epochs=config.training.max_epochs,\n", + " gradient_clip_val=(\n", + " None\n", + " if not model.automatic_optimization\n", + " else config.training.grad_clip_norm_value\n", + " ),\n", + " logger=logger,\n", + " precision=config.training.precision,\n", + ")\n", + "trainer.fit(model, train_dloader, val_dloader)" + ] + }, + { + "cell_type": "markdown", + "id": "3c06421e", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "## Evaluate the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eca722a9", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "model.eval()\n", + "_ = model.cuda()\n", + "eval_frame_idx = 0\n", + "# reducing the data, just for speed\n", + "val_dset.reduce_data(t_list=[eval_frame_idx])\n", + "mmse_count = 10\n", + "overlapping_padding_kwargs = {\n", + " \"mode\": config.data.get(\"padding_mode\", \"constant\"),\n", + "}\n", + "if overlapping_padding_kwargs[\"mode\"] == \"constant\":\n", + " 
overlapping_padding_kwargs[\"constant_values\"] = config.data.get(\"padding_value\", 0)\n", + "val_dset.set_img_sz(\n", + " 128,\n", + " 32,\n", + " grid_alignment=GridAlignement.Center,\n", + " overlapping_padding_kwargs=overlapping_padding_kwargs,\n", + ")\n", + "\n", + "# MMSE prediction\n", + "pred_tiled, rec_loss, logvar_tiled, patch_psnr_tuple, pred_std_tiled = (\n", + " get_dset_predictions(\n", + " model,\n", + " val_dset,\n", + " batch_size,\n", + " num_workers=config.training.num_workers,\n", + " mmse_count=mmse_count,\n", + " model_type=config.model.model_type,\n", + " )\n", + ")\n", + "\n", + "# One sample prediction\n", + "pred1_tiled, *_ = get_dset_predictions(\n", + " model,\n", + " val_dset,\n", + " batch_size,\n", + " num_workers=config.training.num_workers,\n", + " mmse_count=1,\n", + " model_type=config.model.model_type,\n", + ")\n", + "# One sample prediction\n", + "pred2_tiled, *_ = get_dset_predictions(\n", + " model,\n", + " val_dset,\n", + " batch_size,\n", + " num_workers=config.training.num_workers,\n", + " mmse_count=1,\n", + " model_type=config.model.model_type,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c1c9bd5b", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "## Stitching predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38df4c25", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "from denoisplit.analysis.stitch_prediction import stitch_predictions\n", + "\n", + "pred = stitch_predictions(pred_tiled, val_dset)\n", + "\n", + "\n", + "# ignore pixels at the [right/bottom] boundary.\n", + "def print_ignored_pixels():\n", + " ignored_pixels = 1\n", + " while (\n", + " pred[\n", + " 0,\n", + " -ignored_pixels:,\n", + " -ignored_pixels:,\n", + " ].std()\n", + " == 0\n", + " ):\n", + " ignored_pixels += 1\n", + " ignored_pixels -= 1\n", + " return ignored_pixels\n", + "\n", + "\n", + "actual_ignored_pixels = print_ignored_pixels()\n", + "pred = pred[:, :-actual_ignored_pixels, :-actual_ignored_pixels]\n", + "pred1 = stitch_predictions(pred1_tiled, val_dset)[\n", + " :, :-actual_ignored_pixels, :-actual_ignored_pixels\n", + "]\n", + "pred2 = stitch_predictions(pred2_tiled, val_dset)[\n", + " :, :-actual_ignored_pixels, :-actual_ignored_pixels\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "f451a4e1", + "metadata": {}, + "source": [ + "## Get the ground truth" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d0866ed", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "highres_data = get_highsnr_data(config, datapath, DataSplitType.Val)\n", + "\n", + "highres_data = highres_data[\n", + " eval_frame_idx : eval_frame_idx + 1,\n", + " :-actual_ignored_pixels,\n", + " :-actual_ignored_pixels,\n", + "]\n", + "\n", + "noisy_data = val_dset._noise_data[..., 1:] + val_dset._data\n", + "noisy_data = noisy_data[..., :-actual_ignored_pixels, :-actual_ignored_pixels, :]\n", + "model_input = np.mean(noisy_data, axis=-1)" + ] + }, + { + "cell_type": "markdown", + "id": "dc0a5837", + "metadata": {}, + "source": [ + "\n", + "

Checkpoint 1: Model trained

\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "bb4a0b75", + "metadata": { + "cell_marker": "\"\"\"" + }, + "source": [ + "# Qualitative performance on a random crop\n", + "denoiSplit is capable of sampling from a learned posterior.\n", + "Here we show full input frame and a randomly cropped input (300*300),\n", + "two corresponding prediction samples, the difference between the two samples (S1−S2),\n", + "the MMSE prediction, and otherwise unused high SNR microscopy crop. \n", + "The MMSE predictions are computed by averaging 10 samples. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29754975", + "metadata": {}, + "outputs": [], + "source": [ + "def add_str(ax_, txt):\n", + " \"\"\"\n", + " Add psnr string to the axes\n", + " \"\"\"\n", + " textstr = txt\n", + " props = dict(boxstyle=\"round\", facecolor=\"gray\", alpha=0.5)\n", + " # place a text box in upper left in axes coords\n", + " ax_.text(\n", + " 0.05,\n", + " 0.95,\n", + " textstr,\n", + " transform=ax_.transAxes,\n", + " fontsize=11,\n", + " verticalalignment=\"top\",\n", + " bbox=props,\n", + " color=\"white\",\n", + " )\n", + "\n", + "\n", + "ncols = 7\n", + "nrows = 2\n", + "sz = 300\n", + "hs = np.random.randint(0, highres_data.shape[1] - sz)\n", + "ws = np.random.randint(0, highres_data.shape[2] - sz)\n", + "_, ax = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4))\n", + "ax[0, 0].imshow(model_input[0], cmap=\"magma\")\n", + "\n", + "rect = patches.Rectangle((ws, hs), sz, sz, linewidth=1, edgecolor=\"r\", facecolor=\"none\")\n", + "ax[0, 0].add_patch(rect)\n", + "ax[1, 0].imshow(model_input[0, hs : hs + sz, ws : ws + sz], cmap=\"magma\")\n", + "add_str(ax[0, 0], \"Full Input Frame\")\n", + "add_str(ax[1, 0], \"Random Input Crop\")\n", + "\n", + "ax[0, 1].imshow(noisy_data[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + "ax[1, 1].imshow(noisy_data[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "\n", + "ax[0, 2].imshow(pred1[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + "ax[1, 2].imshow(pred1[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "\n", + "ax[0, 3].imshow(pred2[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + "ax[1, 3].imshow(pred2[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "\n", + "diff = pred2 - pred1\n", + "ax[0, 4].imshow(diff[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"coolwarm\")\n", + "ax[1, 4].imshow(diff[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"coolwarm\")\n", + "\n", + "ax[0, 5].imshow(pred[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + "ax[1, 5].imshow(pred[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "\n", + "\n", + "ax[0, 6].imshow(highres_data[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + "ax[1, 6].imshow(highres_data[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "plt.subplots_adjust(wspace=0.02, hspace=0.02)\n", + "ax[0, 0].set_title(\"Model Input\", size=13)\n", + "ax[0, 1].set_title(\"Target\", size=13)\n", + "ax[0, 2].set_title(\"Sample 1 (S1)\", size=13)\n", + "ax[0, 3].set_title(\"Sample 2 (S2)\", size=13)\n", + "ax[0, 4].set_title('\"S2\" - \"S1\"', size=13)\n", + "ax[0, 5].set_title(f\"Prediction MMSE({mmse_count})\", size=13)\n", + "ax[0, 6].set_title(\"High SNR Reality\", size=13)\n", + "\n", + "twinx = ax[0, 6].twinx()\n", + "twinx.set_ylabel(\"Channel 1\", size=13)\n", + "clean_ax(twinx)\n", + "twinx = ax[1, 6].twinx()\n", + "twinx.set_ylabel(\"Channel 2\", size=13)\n", + "clean_ax(twinx)\n", + "clean_ax(ax)" + ] + }, + { + "cell_type": "markdown", + 
"id": "d1ce25bb", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "# Qualitative performance on multiple random crops\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3dc1e50d", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "nimgs = 3\n", + "ncols = 7\n", + "nrows = 2 * nimgs\n", + "sz = 300\n", + "_, ax = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4))\n", + "\n", + "for img_idx in range(nimgs):\n", + " hs = np.random.randint(0, highres_data.shape[1] - sz)\n", + " ws = np.random.randint(0, highres_data.shape[2] - sz)\n", + " ax[2 * img_idx, 0].imshow(model_input[0], cmap=\"magma\")\n", + "\n", + " rect = patches.Rectangle(\n", + " (ws, hs), sz, sz, linewidth=1, edgecolor=\"r\", facecolor=\"none\"\n", + " )\n", + " ax[2 * img_idx, 0].add_patch(rect)\n", + " ax[2 * img_idx + 1, 0].imshow(\n", + " model_input[0, hs : hs + sz, ws : ws + sz], cmap=\"magma\"\n", + " )\n", + " add_str(ax[2 * img_idx, 0], \"Full Input Frame\")\n", + " add_str(ax[2 * img_idx + 1, 0], \"Random Input Crop\")\n", + "\n", + " ax[2 * img_idx, 1].imshow(\n", + " noisy_data[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\"\n", + " )\n", + " ax[2 * img_idx + 1, 1].imshow(\n", + " noisy_data[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\"\n", + " )\n", + "\n", + " ax[2 * img_idx, 2].imshow(pred1[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + " ax[2 * img_idx + 1, 2].imshow(pred1[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "\n", + " ax[2 * img_idx, 3].imshow(pred2[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + " ax[2 * img_idx + 1, 3].imshow(pred2[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "\n", + " diff = pred2 - pred1\n", + " ax[2 * img_idx, 4].imshow(diff[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"coolwarm\")\n", + " ax[2 * img_idx + 1, 4].imshow(\n", + " diff[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"coolwarm\"\n", + " )\n", + "\n", + " ax[2 * img_idx, 5].imshow(pred[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\")\n", + " ax[2 * img_idx + 1, 5].imshow(pred[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\")\n", + "\n", + " ax[2 * img_idx, 6].imshow(\n", + " highres_data[0, hs : hs + sz, ws : ws + sz, 0], cmap=\"magma\"\n", + " )\n", + " ax[2 * img_idx + 1, 6].imshow(\n", + " highres_data[0, hs : hs + sz, ws : ws + sz, 1], cmap=\"magma\"\n", + " )\n", + "\n", + " twinx = ax[2 * img_idx, 6].twinx()\n", + " twinx.set_ylabel(\"Channel 1\", size=15)\n", + " clean_ax(twinx)\n", + "\n", + " twinx = ax[2 * img_idx + 1, 6].twinx()\n", + " twinx.set_ylabel(\"Channel 2\", size=15)\n", + " clean_ax(twinx)\n", + "\n", + "ax[0, 0].set_title(\"Model Input\", size=15)\n", + "ax[0, 1].set_title(\"Target\", size=15)\n", + "ax[0, 2].set_title(\"Sample 1 (S1)\", size=15)\n", + "ax[0, 3].set_title(\"Sample 2 (S2)\", size=15)\n", + "ax[0, 4].set_title('\"S2\" - \"S1\"', size=15)\n", + "ax[0, 5].set_title(f\"Prediction MMSE({mmse_count})\", size=15)\n", + "ax[0, 6].set_title(\"High SNR Reality\", size=15)\n", + "\n", + "clean_ax(ax)\n", + "plt.subplots_adjust(wspace=0.02, hspace=0.02)\n", + "# plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "3db2fe0b", + "metadata": {}, + "source": [ + "
\n", + "

Questions:

\n", + " 1) When is it relatively easy to split the two structures from the input?
\n", + " 2) Why might you see the grid-like artifacts and what can be done to mitigate this?
\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "id": "ffbd27cb", + "metadata": { + "tags": [ + "solution" + ] + }, + "source": [ + "
\n", + "

Answers:

\n", + " 1) When there is less noise. Then things are more clearly present in the input. And, when the two structures are very different. It is easy to separate lines from dots. (CCP vs ER).
\n", + " 2) These are tiling artifacts, where the network assigns different background levels to neighboring patches. This becomes more prominent when the structures have haze. One way is to include more samples in the mmse.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "d94ea88e", + "metadata": { + "cell_marker": "\"\"\"", + "lines_to_next_cell": 0 + }, + "source": [ + "## Quantitative performance\n", + "We evaluate on two metrics, Multiscale SSIM and PSNR.\n", + "\n", + "Multi-scale SSIM is a metric that computes SSIM at multiple scales and averages them. It's reminiscent of multiscale processing in the early vision system \n", + "\n", + "PSNR is a metric that computes the peak signal-to-noise ratio. It's one of the most widely used metrics to measure the quality of image reconstruction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc3246db", + "metadata": {}, + "outputs": [], + "source": [ + "mean_tar = mean_dict[\"target\"].cpu().numpy().squeeze().reshape(1, 1, 1, 2)\n", + "std_tar = std_dict[\"target\"].cpu().numpy().squeeze().reshape(1, 1, 1, 2)\n", + "pred_unnorm = pred * std_tar + mean_tar\n", + "\n", + "psnr_list = [\n", + " avg_range_inv_psnr(highres_data[..., i].copy(), pred_unnorm[..., i].copy())\n", + " for i in range(highres_data.shape[-1])\n", + "]\n", + "ssim_list = compute_multiscale_ssim(highres_data.copy(), pred_unnorm.copy())\n", + "print(\"Metric: Ch1\\t Ch2\")\n", + "print(f\"PSNR : {psnr_list[0]:.2f}\\t {psnr_list[1]:.2f}\")\n", + "print(f\"MS-SSIM : {ssim_list[0]:.3f}\\t {ssim_list[1]:.3f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "ea1bdbbe", + "metadata": {}, + "source": [ + "

Checkpoint 2: Try one of the \"Several things to try\"

\n", + "\n", + "
\n", + "\n", + "Click [here](#things-to-try) to go back to the relevant section." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

End of the exercise

\n", + "
" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "all", + "main_language": "python" + }, + "kernelspec": { + "display_name": "05_image_restoration", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/05_bonus_Noise2Noise/n2n_exercise.ipynb b/05_bonus_Noise2Noise/n2n_exercise.ipynb new file mode 100644 index 0000000..9f4924e --- /dev/null +++ b/05_bonus_Noise2Noise/n2n_exercise.ipynb @@ -0,0 +1,336 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Noise2Noise\n", + "\n", + "CARE network you trained in the first restoration exercises require that you acquire pairs\n", + "of high and low SNR. However, this often not possible. One such case is when it is simply\n", + "not possible to acquire high SNR images.\n", + "\n", + "What to do when you are stuck with just noisy images? We also have seen Noise2Void, which\n", + "is a self-supervised method that can be trained on noisy images. But there are other \n", + "supervised approaches that can be trained on noisy images only, such as Noise2Noise. \n", + "\n", + "Noise2Noise relies on the same assumption than Noise2Void: the noise is pixel-independent.\n", + "Therefore, if you supervise your network to guess a noisy image from another one, the network\n", + "will converge to a denoised image. Of course, this only works if the two noisy images are\n", + "very similar.\n", + "\n", + "To acquire data for Noise2Noise, one can simply image the same region of interest twice!\n", + "Indeed, pixel-independent noise (as opposed to structured noise) will be completely independent\n", + "between neighboring pixels and as well as between the two noisy images.\n", + "\n", + "In this notebook, we will again use the [Careamics](https://careamics.github.io) library.\n", + "\n", + "## Reference\n", + "\n", + "Lehtinen, Jaakko, et al. \"[Noise2Noise: Learning image restoration without clean data.](https://arxiv.org/abs/1803.04189)\" arXiv preprint arXiv:1803.04189 (2018).\n", + "\n", + "\n", + "

Objectives

\n", + " \n", + "- Understand the differences between CARE, Noise2Noise and Noise2Void\n", + "- Train Noise2Noise with CAREamics\n", + " \n", + "
\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Set your python kernel to 05_image_restoration\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import tifffile\n", + "\n", + "from careamics import CAREamist\n", + "from careamics.config import create_n2n_configuration" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## Part 1: Prepare the data\n", + "\n", + "The N2N SEM dataset consists of EM images with 7 different levels of noise:\n", + "\n", + "- Image 0 is recorded with 0.2 us scan time\n", + "- Image 1 is recorded with 0.5 us scan time\n", + "- Image 2 is recorded with 1 us scan time\n", + "- Image 3 is recorded with 1 us scan time\n", + "- Image 4 is recorded with 2.1 us scan time\n", + "- Image 5 is recorded with 5.0 us scan time\n", + "- Image 6 is recorded with 5.0 us scan time and is the avg. of 4 images\n", + "\n", + "Let's have a look at them.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize training data\n", + "\n", + "In this cell we can see the different levels of noise in the SEM dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load images\n", + "root_path = Path(\"./../data\")\n", + "train_image = tifffile.imread(root_path / \"denoising-N2N_SEM.unzip/SEM/train.tif\")\n", + "print(f\"Train image shape: {train_image.shape}\")\n", + "\n", + "# plot image\n", + "fig, ax = plt.subplots(1, 2, figsize=(10, 10))\n", + "ax[0].imshow(train_image[0,100:356, 500:756], cmap=\"gray\")\n", + "ax[0].set_title(\"Train image highest noise level\")\n", + "ax[1].imshow(train_image[-1, 100:356, 500:756], cmap=\"gray\")\n", + "ax[1].set_title(\"Train image lowest noise level\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Task 1: Explore the data

\n", + "\n", + "Visualize each different noise level!\n", + "\n", + "
\n", + "\n", + "
\n", + "\n", + "## Part 2: Create the configuraion\n", + "\n", + "As in the Noise2Void exercise, a good CAREamics pipeline starts with a configuration!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "training_config = create_n2n_configuration(\n", + " experiment_name=\"N2N_SEM\",\n", + " data_type=\"array\",\n", + " axes=\"SYX\",\n", + " patch_size=[128, 128],\n", + " batch_size=128,\n", + " num_epochs=50,\n", + " logger=\"tensorboard\"\n", + ")\n", + "\n", + "# Visualize training configuration\n", + "print(training_config)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## Part 3: Train the network\n", + "\n", + "In this part, we create our training engine (`CAREamics`) and start training the network." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create the engine\n", + "careamist = CAREamist(source=training_config)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Task 2: Which data to choose?

\n", + "\n", + "How would you train a network to denoise images of 1 us scan time? Which images do you think could be used as input and which as target?\n", + "\n", + "Set the `train_source` and `train_target` accordingly and train the network.\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create the training data and targets\n", + "train_data = train_image[[2, 2, 2, 2, 2, 3, 3, 3, 3, 3], ...]\n", + "train_target = train_image[[0, 1, 3, 4, 5, 0, 1, 3, 4, 5], ...]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "careamist.train(\n", + " train_source=...,\n", + " train_target=...\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Checkpoint 1: Training N2N

\n", + "
\n", + "\n", + "\n", + "
\n", + "\n", + "## Part 4: Prediction" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's load the test data and predict on it to assess how well the network performs!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load images\n", + "test_image = tifffile.imread(root_path / \"denoising-N2N_SEM.unzip/SEM/test.tif\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prediction = careamist.predict(source=test_image[2], tile_size=(256, 256), axes=\"YX\", tta_transforms=False)[0]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 2, figsize=(10, 10))\n", + "ax[0].imshow(test_image[-1], cmap=\"gray\")\n", + "ax[0].set_title(\"Test image lowest noise level\")\n", + "ax[1].imshow(prediction[0, 0], cmap=\"gray\")\n", + "ax[1].set_title(\"Prediction\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fi, ax = plt.subplots(1, 2, figsize=(15, 15))\n", + "vim = test_image[0].min()\n", + "vmax = test_image[0].max()\n", + "ax[0].imshow((prediction.squeeze())[1000:1128, 500:628], cmap=\"gray\",vmin=vim, vmax=vmax)\n", + "ax[0].set_title(\"Prediction\")\n", + "ax[1].imshow(test_image[-1].squeeze()[1000:1128, 500:628], cmap=\"gray\", vmin=vim, vmax=vmax)\n", + "ax[1].set_title(\"Test image lowest noise level\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Task 3: Different noise pairs

\n", + "\n", + "Can you further improve your results by usign different `source` and `target`?\n", + "\n", + "How would you train a network to denoise all images, rather than just the 1 us ones?\n", + "\n", + "Try it and be creative!\n", + "\n", + "
" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cmcs_l", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/05_bonus_Noise2Noise/n2n_solution.ipynb b/05_bonus_Noise2Noise/n2n_solution.ipynb new file mode 100755 index 0000000..68b6ea3 --- /dev/null +++ b/05_bonus_Noise2Noise/n2n_solution.ipynb @@ -0,0 +1,352 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Noise2Noise\n", + "\n", + "CARE network you trained in the first restoration exercises require that you acquire pairs\n", + "of high and low SNR. However, this often not possible. One such case is when it is simply\n", + "not possible to acquire high SNR images.\n", + "\n", + "What to do when you are stuck with just noisy images? We also have seen Noise2Void, which\n", + "is a self-supervised method that can be trained on noisy images. But there are other \n", + "supervised approaches that can be trained on noisy images only, such as Noise2Noise. \n", + "\n", + "Noise2Noise relies on the same assumption than Noise2Void: the noise is pixel-independent.\n", + "Therefore, if you supervise your network to guess a noisy image from another one, the network\n", + "will converge to a denoised image. Of course, this only works if the two noisy images are\n", + "very similar.\n", + "\n", + "To acquire data for Noise2Noise, one can simply image the same region of interest twice!\n", + "Indeed, pixel-independent noise (as opposed to structured noise) will be completely independent\n", + "between neighboring pixels and as well as between the two noisy images.\n", + "\n", + "In this notebook, we will again use the [Careamics](https://careamics.github.io) library.\n", + "\n", + "## Reference\n", + "\n", + "Lehtinen, Jaakko, et al. \"[Noise2Noise: Learning image restoration without clean data.](https://arxiv.org/abs/1803.04189)\" arXiv preprint arXiv:1803.04189 (2018).\n", + "\n", + "\n", + "

Objectives

\n", + " \n", + "- Understand the differences between CARE, Noise2Noise and Noise2Void\n", + "- Train Noise2Noise with CAREamics\n", + " \n", + "
\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Set your python kernel to 05_image_restoration\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import tifffile\n", + "\n", + "from careamics import CAREamist\n", + "from careamics.config import create_n2n_configuration" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## Part 1: Prepare the data\n", + "\n", + "The N2N SEM dataset consists of EM images with 7 different levels of noise:\n", + "\n", + "- Image 0 is recorded with 0.2 us scan time\n", + "- Image 1 is recorded with 0.5 us scan time\n", + "- Image 2 is recorded with 1 us scan time\n", + "- Image 3 is recorded with 1 us scan time\n", + "- Image 4 is recorded with 2.1 us scan time\n", + "- Image 5 is recorded with 5.0 us scan time\n", + "- Image 6 is recorded with 5.0 us scan time and is the avg. of 4 images\n", + "\n", + "Let's have a look at them.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize training data\n", + "\n", + "In this cell we can see the different levels of noise in the SEM dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load images\n", + "root_path = Path(\"./../data\")\n", + "train_image = tifffile.imread(root_path / \"denoising-N2N_SEM.unzip/SEM/train.tif\")\n", + "print(f\"Train image shape: {train_image.shape}\")\n", + "\n", + "# plot image\n", + "fig, ax = plt.subplots(1, 2, figsize=(10, 10))\n", + "ax[0].imshow(train_image[0,100:356, 500:756], cmap=\"gray\")\n", + "ax[0].set_title(\"Train image highest noise level\")\n", + "ax[1].imshow(train_image[-1, 100:356, 500:756], cmap=\"gray\")\n", + "ax[1].set_title(\"Train image lowest noise level\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Task 1: Explore the data

\n", + "\n", + "Visualize each different noise level!\n", + "\n", + "
\n", + "\n", + "
\n", + "\n", + "## Part 2: Create the configuraion\n", + "\n", + "As in the Noise2Void exercise, a good CAREamics pipeline starts with a configuration!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "training_config = create_n2n_configuration(\n", + " experiment_name=\"N2N_SEM\",\n", + " data_type=\"array\",\n", + " axes=\"SYX\",\n", + " patch_size=[128, 128],\n", + " batch_size=128,\n", + " num_epochs=50,\n", + " logger=\"tensorboard\"\n", + ")\n", + "\n", + "# Visualize training configuration\n", + "print(training_config)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "## Part 3: Train the network\n", + "\n", + "In this part, we create our training engine (`CAREamics`) and start training the network." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create the engine\n", + "careamist = CAREamist(source=training_config)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Task 2: Which data to choose?

\n", + "\n", + "How would you train a network to denoise images of 1 us scan time? Which images do you think could be used as input and which as target?\n", + "\n", + "Set the `train_source` and `train_target` accordingly and train the network.\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create the training data and targets\n", + "train_data = train_image[[2, 2, 2, 2, 2, 3, 3, 3, 3, 3], ...]\n", + "train_target = train_image[[0, 1, 3, 4, 5, 0, 1, 3, 4, 5], ...]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "task" + ] + }, + "outputs": [], + "source": [ + "careamist.train(\n", + " train_source=...,\n", + " train_target=...\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "careamist.train(\n", + " train_source=train_data,\n", + " train_target=train_target\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Checkpoint 1: Training N2N

\n", + "
\n", + "\n", + "\n", + "
\n", + "\n", + "## Part 4: Prediction" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's load the test data and predict on it to assess how well the network performs!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load images\n", + "test_image = tifffile.imread(root_path / \"denoising-N2N_SEM.unzip/SEM/test.tif\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prediction = careamist.predict(source=test_image[2], tile_size=(256, 256), axes=\"YX\", tta_transforms=False)[0]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 2, figsize=(10, 10))\n", + "ax[0].imshow(test_image[-1], cmap=\"gray\")\n", + "ax[0].set_title(\"Test image lowest noise level\")\n", + "ax[1].imshow(prediction[0, 0], cmap=\"gray\")\n", + "ax[1].set_title(\"Prediction\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fi, ax = plt.subplots(1, 2, figsize=(15, 15))\n", + "vim = test_image[0].min()\n", + "vmax = test_image[0].max()\n", + "ax[0].imshow((prediction.squeeze())[1000:1128, 500:628], cmap=\"gray\",vmin=vim, vmax=vmax)\n", + "ax[0].set_title(\"Prediction\")\n", + "ax[1].imshow(test_image[-1].squeeze()[1000:1128, 500:628], cmap=\"gray\", vmin=vim, vmax=vmax)\n", + "ax[1].set_title(\"Test image lowest noise level\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Task 3: Different noise pairs

\n", + "\n", + "Can you further improve your results by usign different `source` and `target`?\n", + "\n", + "How would you train a network to denoise all images, rather than just the 1 us ones?\n", + "\n", + "Try it and be creative!\n", + "\n", + "
" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cmcs_l", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/LICENSE b/LICENSE new file mode 100755 index 0000000..13a1dc5 --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ +BSD 3-Clause License + +Copyright (c) 2024, DL4MIA + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md old mode 100644 new mode 100755 index b180df1..c1ffa20 --- a/README.md +++ b/README.md @@ -1,69 +1,34 @@ -# Exercise 3: Image Restoration +# Image Restoration: denoising and splitting -In this exercise you will get to try out some of the image restoration techniques that you just learned about in the lecture. -Start by setting up all the environments and downloading the example data: Open a terminal and run `source setup.sh`. +Welcome to the Image Restoration exercises. In this part of the course, we will explore +how to use deep learning to denoise images, with examples of widely used algorithm for +both supervised and unsupervised denoising. We will also explore the difference +between unstructured and structured noise, or between UNet (which you are familiar with +by now) and VAE architectures (see COSDD exercise)! -In the first part of the exercise `exercise1.ipynb` you will use paired images with high and low signal to noise ratios to train a supervised CARE network. The second part (`exercise2.ipynb`) you will train a Noise2Noise network with multiple SEM acquisitions of the same sample at various noise levels. And in part 3 (`exercise3.ipynb`) you can train a Noise2Void network on your own data (if you'd like). +Finally, we have bonus exercises for those wanted to explore more denoising algorithms or +image splitting! 
-All exercise notebooks are closely modeled after example notebooks that were provided as part of the respective repositories by their authors. This won't always be the case but think of this exercise as a good example of the situation you'll find yourself in if you find a deep learning method "in the wild" that you would like to try out yourself. -If you have extra time in the end check out `exercise_bonus1.ipynb` if you're interested in Probabilistic Noise2Void or `exercise_bonus2.md` for DivNoising where you will go one step further by cloning the repo yourself, setting up your own environment and running an example notebook from the repo. +## Setup +Please run the setup script to create the environment for these exercises and download data. +``` bash +source setup.sh +``` +## Exercises +1. [Context-aware restoration](01_CARE/care_exercise.ipynb) +2. [Noise2Void](02_Noise2Void/n2v_exercise.ipynb) +3. [Correlated and Signal Dependent Denoising (COSDD)](03_COSDD/exercise.ipynb) +4. [DenoiSplit](04_DenoiSplit/exercise.ipynb) +## Bonus +- [Noise2Noise](05_bonus_Noise2Noise/n2n_exercise.ipynb) -## Task Overview for TAs - -### Exercise1 -#### Questions: -- where is the training data located? -- how is the data organized to identify the pairs of HR and LR images? -#### Questions: -- Where are the trained models stored? What models are being stored, how do they differ? -- How does the name of the saved models get specified? -- How can you influence the number of training steps per epoch? What did you use? - --> CHECKPOINT1 - -### Exercise2 -#### Task 2.1: -- Crop image for visualization to get a feeling for what the data looks like -#### Task 2.2: -- Pick input and target images for Noise2Noise -#### Task 2.3: -- create raw data object -#### Task 2.4: -- write function that applies the model to one of the images -#### Task 2.5: -- play around by tweaking setup and/or train network for all scan times --> CHECKPOINT2 - -### Exercise3 -#### Task 3.1: -- use your own data for N2V -#### Task 3.2: -- configure N2V model -#### Task 3.3: -- measure performance (if high SNR image available) --> CHECKPOINT3 - -### Bonus Exercise 1 -#### Task 4.1 -- estimate clean signal from calibration data -#### Task 4.2 -- create histogram from calibration data -#### Task 4.3 -- create histogram from bootstrapped signal -#### Task 4.4 -- train PN2V model -#### Task 4.5 -- try PN2V model for your own data - -### Bonus Exercise 2 -- run an example notebook from DivNoising repo diff --git a/download_careamics_portfolio.py b/download_careamics_portfolio.py new file mode 100644 index 0000000..3bd8a9e --- /dev/null +++ b/download_careamics_portfolio.py @@ -0,0 +1,10 @@ +from pathlib import Path + +from careamics_portfolio import PortfolioManager + + +portfolio = PortfolioManager() +root_path = Path("./data") +portfolio.denoising.N2V_SEM.download(root_path) +portfolio.denoising.CARE_U2OS.download(root_path) +portfolio.denoising.N2N_SEM.download(root_path) diff --git a/exercise1.ipynb b/exercise1.ipynb deleted file mode 100644 index 784359e..0000000 --- a/exercise1.ipynb +++ /dev/null @@ -1,674 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "7a74d84f", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "
\n", - "\n", - "# Train your first CARE model (supervised)\n", - "\n", - "In this first example we will train a CARE model for a 2D denoising and upsampling task, where corresponding pairs of low and high signal-to-noise ratio (SNR) images of cells are available. Here the high SNR images are acquisitions of Human U2OS cells taken from the [Broad Bioimage Benchmark Collection](https://data.broadinstitute.org/bbbc/BBBC006/) and the low SNR images were created by synthetically adding *strong read-out and shot-noise* and applying *pixel binning* of 2x2, thus mimicking acquisitions at a very low light level.\n", - "\n", - "![](nb_material/denoising_binning_overview.png)\n", - "\n", - "\n", - "For CARE, image pairs should be registered, which in practice is best achieved by acquiring both stacks _interleaved_, i.e. as different channels that correspond to the different exposure/laser settings.\n", - "\n", - "Since the image pairs were synthetically created in this example, they are already aligned perfectly. Note that when working with real paired acquisitions, the low and high SNR images are not pixel-perfect aligned so typically need to be co-registered before training a CARE model.\n", - "\n", - "To train a denoising network, we will use the [CSBDeep Repo](https://github.com/CSBDeep/CSBDeep). This notebook has a very similar structure to the examples you can find there.\n", - "More documentation is available at http://csbdeep.bioimagecomputing.com/doc/.\n", - "\n", - "This part will not have any coding tasks, but go through each cell and try to understand what's going on - it will help you in the next part! We also put some questions along the way. For some of them you might need to dig a bit deeper.\n", - "\n", - "
\n", - "Set your python kernel to 03_image_restoration_part1\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7cef58ce", - "metadata": {}, - "outputs": [], - "source": [ - "from __future__ import absolute_import, division, print_function, unicode_literals\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import os\n", - "import numpy as np\n", - "from csbdeep.data import (\n", - " RawData,\n", - " create_patches,\n", - " no_background_patches,\n", - " norm_percentiles,\n", - " sample_percentiles,\n", - ")\n", - "from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n", - "from csbdeep.models import CARE, Config\n", - "from csbdeep.utils import (\n", - " Path,\n", - " axes_dict,\n", - " normalize,\n", - " plot_history,\n", - " plot_some,\n", - ")\n", - "from csbdeep.utils.tf import limit_gpu_memory\n", - "\n", - "%matplotlib inline\n", - "%load_ext tensorboard\n", - "%config InlineBackend.figure_format = 'retina'\n", - "from tifffile import imread" - ] - }, - { - "cell_type": "markdown", - "id": "ef4f4c66", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Part 1: Training Data Generation\n", - "Network training usually happens on batches of smaller sized images than the ones recorded on a microscopy. In this first part of the exercise, we will load all of the image data and chop it into smaller pieces, a.k.a. patches.\n", - "\n", - "### Look at example data\n", - "\n", - "During setup, we downloaded some example data, consisting of low-SNR and high-SNR 3D images of Tribolium.\n", - "Note that `GT` stands for ground truth and represents high signal-to-noise ratio (SNR) stacks." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "02cea28e", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "assert os.path.exists(\"data/U2OS\")" - ] - }, - { - "cell_type": "markdown", - "id": "f8bd941c", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "As we can see, the data set is already split into a **train** and **test** set, each containing (synthetically generated) low SNR (\"low\") and corresponding high SNR (\"GT\") images.\n", - "\n", - "Let's look at an example pair of training images:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a2864715", - "metadata": {}, - "outputs": [], - "source": [ - "y = imread(\"data/U2OS/train/GT/img_0010.tif\")\n", - "x = imread(\"data/U2OS/train/low/img_0010.tif\")\n", - "print(\"GT image size =\", x.shape)\n", - "print(\"low-SNR image size =\", y.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17f182a2", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "plt.figure(figsize=(13, 5))\n", - "plt.subplot(1, 2, 1)\n", - "plt.imshow(x, cmap=\"magma\")\n", - "plt.colorbar()\n", - "plt.title(\"low\")\n", - "plt.subplot(1, 2, 2)\n", - "plt.imshow(y, cmap=\"magma\")\n", - "plt.colorbar()\n", - "plt.title(\"high\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "e78564f8", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Generate training data for CARE\n", - "\n", - "We first need to create a `RawData` object, which defines how to get the pairs of low/high SNR stacks and the semantics of each axis (e.g. which one is considered a color channel, etc.). In general the names for the axes are:\n", - "\n", - "X: columns, Y: rows, Z: planes, C: channels, T: frames/time, (S: samples/images)\n", - "\n", - "Here we have two folders \"low\" and \"GT\", where corresponding low and high-SNR stacks are TIFF images with identical filenames.\n", - "\n", - "For this case, we can simply use `RawData.from_folder` and set `axes = 'YX'` to indicate the semantic order of the image axes, i.e. we have two-dimensional images in standard xy layout." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6f9c6cdd", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "raw_data = RawData.from_folder(\n", - " basepath=\"data/U2OS/train\",\n", - " source_dirs=[\"low\"],\n", - " target_dir=\"GT\",\n", - " axes=\"YX\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "3cef3683", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "From corresponding images, we now generate some 2D patches to use for training.\n", - "\n", - "As a general rule, use a *patch size* that is a power of two along all axes, or at least divisible by 8. 
Typically, you should use more patches the more trainings images you have.\n", - "\n", - "An important aspect is *data normalization*, i.e. the rescaling of corresponding patches to a dynamic range of ~ (0,1). By default, this is automatically provided via percentile normalization, which can be adapted if needed.\n", - "\n", - "By default, patches are sampled from *non-background regions* (i.e. that are above a relative threshold). We will disable this for the current example as most image regions already contain foreground pixels and thus set the threshold to 0. See the documentation of `create_patches` for details.\n", - "\n", - "Note that returned values `(X, Y, XY_axes)` by `create_patches` are not to be confused with the image axes X and Y. By convention, the variable name X (or x) refers to an input variable for a machine learning model, whereas Y (or y) indicates an output variable." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "44e72ed6", - "metadata": {}, - "outputs": [], - "source": [ - "X, Y, XY_axes = create_patches(\n", - " raw_data=raw_data,\n", - " patch_size=(128, 128),\n", - " patch_filter=no_background_patches(0),\n", - " n_patches_per_image=2,\n", - " save_file=\"data/U2OS/my_training_data.npz\",\n", - ")\n", - "\n", - "assert X.shape == Y.shape\n", - "print(\"shape of X,Y =\", X.shape)\n", - "print(\"axes of X,Y =\", XY_axes)" - ] - }, - { - "cell_type": "markdown", - "id": "eb760f8a", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Show\n", - "\n", - "This shows some of the generated patch pairs (odd rows: *source*, even rows: *target*)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f7eeb090", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "for i in range(2):\n", - " plt.figure(figsize=(16, 4))\n", - " sl = slice(8 * i, 8 * (i + 1)), 0\n", - " plot_some(\n", - " X[sl], Y[sl], title_list=[np.arange(sl[0].start, sl[0].stop)]\n", - " ) # convenience function provided by CSB Deep\n", - " plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "cf8015ae", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "

\n", - " Questions:

\n", - "
    \n", - "
  1. Where is the training data located?
  2. \n", - "
  3. How is the data organized to identify the pairs of HR and LR images?
  4. \n", - "
\n", - "
\n", - "\n", - "
\n", - "\n", - "## Part 2: Training the network\n", - "\n", - "\n", - "### Load Training data\n", - "\n", - "Load the patches generated in part 1, use 10% as validation data." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bc8ab655", - "metadata": {}, - "outputs": [], - "source": [ - "(X, Y), (X_val, Y_val), axes = load_training_data(\n", - " \"data/U2OS/my_training_data.npz\", validation_split=0.1, verbose=True\n", - ")\n", - "\n", - "c = axes_dict(axes)[\"C\"]\n", - "n_channel_in, n_channel_out = X.shape[c], Y.shape[c]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e344b63d", - "metadata": {}, - "outputs": [], - "source": [ - "plt.figure(figsize=(12, 5))\n", - "plot_some(X_val[:5], Y_val[:5])\n", - "plt.suptitle(\"5 example validation patches (top row: source, bottom row: target)\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "dbc0df28", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Configure the CARE model\n", - "Before we construct the actual CARE model, we have to define its configuration via a `Config` object, which includes\n", - "* parameters of the underlying neural network,\n", - "* the learning rate,\n", - "* the number of parameter updates per epoch,\n", - "* the loss function, and\n", - "* whether the model is probabilistic or not.\n", - "\n", - "![](nb_material/carenet.png)\n", - "\n", - "The defaults should be sensible in many cases, so a change should only be necessary if the training process fails.\n", - "\n", - "Important: Note that for this notebook we use a very small number of update steps for immediate feedback, whereas the number of epochs and steps per epoch should be increased considerably (e.g. `train_steps_per_epoch=400`, `train_epochs=100`) to obtain a well-trained model." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f6536160", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "config = Config(\n", - " axes,\n", - " n_channel_in,\n", - " n_channel_out,\n", - " train_batch_size=8,\n", - " train_steps_per_epoch=40,\n", - " train_epochs=20,\n", - ")\n", - "vars(config)" - ] - }, - { - "cell_type": "markdown", - "id": "9aecabf8", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "We now create a CARE model with the chosen configuration:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5dc0fcf5", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "model = CARE(config, \"my_CARE_model\", basedir=\"models\")" - ] - }, - { - "cell_type": "markdown", - "id": "c99e6540", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "We can get a summary of all the layers in the model and the number of parameters:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f73a5754", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "model.keras_model.summary()" - ] - }, - { - "cell_type": "markdown", - "id": "2f15a9d1", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Training\n", - "\n", - "Training the model will likely take some time. 
We recommend to monitor the progress with [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard), which allows you to inspect the losses during training.\n", - "Furthermore, you can look at the predictions for some of the validation images, which can be helpful to recognize problems early on.\n", - "\n", - "We can start tensorboard within the notebook.\n", - "\n", - "Alternatively, you can launch the notebook in an independent tab by changing the `%` to `!`\n", - "
\n", - "If you're using ssh add --host <hostname> to the command:\n", - "! tensorboard --logdir models --host <hostname> where <hostname> is the thing that ends in amazonaws.com.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0409db29", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "%tensorboard --logdir models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a011e9e7", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "history = model.train(X, Y, validation_data=(X_val, Y_val))" - ] - }, - { - "cell_type": "markdown", - "id": "1a945056", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Plot final training history (available in TensorBoard during training):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7328eec2", - "metadata": {}, - "outputs": [], - "source": [ - "print(sorted(list(history.history.keys())))\n", - "plt.figure(figsize=(16, 5))\n", - "plot_history(history, [\"loss\", \"val_loss\"], [\"mse\", \"val_mse\", \"mae\", \"val_mae\"])" - ] - }, - { - "cell_type": "markdown", - "id": "7bccaf8b", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Evaluation\n", - "Example results for validation images." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0b2914d3", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "plt.figure(figsize=(12, 7))\n", - "_P = model.keras_model.predict(X_val[:5])\n", - "if config.probabilistic:\n", - " _P = _P[..., : (_P.shape[-1] // 2)]\n", - "plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5)\n", - "plt.suptitle(\n", - " \"5 example validation patches\\n\"\n", - " \"top row: input (source), \"\n", - " \"middle row: target (ground truth), \"\n", - " \"bottom row: predicted from source\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "11e9d20e", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "

\n", - " Questions:

\n", - "
    \n", - "
  1. Where are trained models stored? What models are being stored, how do they differ?
  2. \n", - "
  3. How does the name of the saved models get specified?
  4. \n", - "
  5. How can you influence the number of training steps per epoch? What did you use?
  6. \n", - "
\n", - "
\n", - "\n", - "
\n", - "\n", - "## Part 3: Prediction\n", - "\n", - "Plot the test stack pair and define its image axes, which will be needed later for CARE prediction." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d1dc0da6", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "y_test = imread(\"data/U2OS/test/GT/img_0010.tif\")\n", - "x_test = imread(\"data/U2OS/test/low/img_0010.tif\")\n", - "\n", - "axes = \"YX\"\n", - "print(\"image size =\", x_test.shape)\n", - "print(\"image axes =\", axes)\n", - "\n", - "plt.figure(figsize=(16, 10))\n", - "plot_some(np.stack([x_test, y_test]), title_list=[[\"low\", \"high\"]])" - ] - }, - { - "cell_type": "markdown", - "id": "2d4c79ec", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Load CARE model\n", - "\n", - "Load trained model (located in base directory `models` with name `my_CARE_model`) from disk.\n", - "The configuration was saved during training and is automatically loaded when `CARE` is initialized with `config=None`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "94c55e51", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "model = CARE(config=None, name=\"my_CARE_model\", basedir=\"models\")" - ] - }, - { - "cell_type": "markdown", - "id": "4196e6ea", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Apply CARE network to raw image\n", - "Predict the restored image (image will be successively split into smaller tiles if there are memory issues)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12c86090", - "metadata": {}, - "outputs": [], - "source": [ - "%%time\n", - "restored = model.predict(x_test, axes)" - ] - }, - { - "cell_type": "markdown", - "id": "70ecb045", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Save restored image\n", - "\n", - "Save the restored image stack as a ImageJ-compatible TIFF image, i.e. the image can be opened in ImageJ/Fiji with correct axes semantics." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b2846eed", - "metadata": {}, - "outputs": [], - "source": [ - "Path(\"results\").mkdir(exist_ok=True)\n", - "save_tiff_imagej_compatible(\"results/%s_img_0010.tif\" % model.name, restored, axes)" - ] - }, - { - "cell_type": "markdown", - "id": "98c3d39b", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "### Visualize results\n", - "Plot the test stack pair and the predicted restored stack (middle)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1f8239d3", - "metadata": {}, - "outputs": [], - "source": [ - "plt.figure(figsize=(15, 10))\n", - "plot_some(\n", - " np.stack([x_test, restored, y_test]),\n", - " title_list=[[\"low\", \"CARE\", \"GT\"]],\n", - " pmin=2,\n", - " pmax=99.8,\n", - ")\n", - "\n", - "plt.figure(figsize=(10, 5))\n", - "for _x, _name in zip((x_test, restored, y_test), (\"low\", \"CARE\", \"GT\")):\n", - " plt.plot(normalize(_x, 1, 99.7)[180], label=_name, lw=2)\n", - "plt.legend()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "dc9bc434", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "
\n", - "

\n", - " Congratulations!

\n", - "

\n", - " You have reached the first checkpoint of this exercise! Please mark your progress in the course chat!\n", - "

\n", - "
" - ] - } - ], - "metadata": { - "jupytext": { - "cell_metadata_filter": "all", - "main_language": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/exercise2.ipynb b/exercise2.ipynb deleted file mode 100644 index 0a99ae9..0000000 --- a/exercise2.ipynb +++ /dev/null @@ -1,803 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "a3ef0059", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "# Train a Noise2Noise network with CARE\n", - "
\n", - "Set your python kernel to 03_image_restoration_part1! That's the same as for the first notebook.\n", - "
\n", - "\n", - "We will now train a 2D Noise2Noise network using CARE. We will closely follow along the previous example but now you will have to fill in some parts on your own!\n", - "You will have to make decisions - make them!\n", - "\n", - "But first some clean up...\n", - "
\n", - "Make sure your previous notebook is shutdown to avoid running into GPU out-of-memory problems.\n", - "
\n", - "\n", - "![](nb_material/notebook_shutdown.png)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cd71b96a", - "metadata": {}, - "outputs": [], - "source": [ - "from __future__ import absolute_import, division, print_function, unicode_literals\n", - "\n", - "import gc\n", - "import os\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "from csbdeep.data import RawData, create_patches\n", - "from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n", - "from csbdeep.models import CARE, Config\n", - "from csbdeep.utils import (\n", - " Path,\n", - " axes_dict,\n", - " plot_history,\n", - " plot_some,\n", - ")\n", - "from csbdeep.utils.tf import limit_gpu_memory\n", - "\n", - "%matplotlib inline\n", - "%load_ext tensorboard\n", - "%config InlineBackend.figure_format = 'retina'\n", - "from skimage.metrics import peak_signal_noise_ratio, structural_similarity\n", - "from tifffile import imread, imwrite" - ] - }, - { - "cell_type": "markdown", - "id": "a04d9ec0", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Part 1: Training Data Generation\n", - "\n", - "### Download example data\n", - "\n", - "To train a Noise2Noise setup we need several acquisitions of the same sample.\n", - "The SEM data we downloaded during setup contains 2 tiff-stacks, one for training and one for testing, let's make sure it's there!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3eacfb41", - "metadata": {}, - "outputs": [], - "source": [ - "assert os.path.exists(\"data/SEM/train/train.tif\")\n", - "assert os.path.exists(\"data/SEM/test/test.tif\")" - ] - }, - { - "cell_type": "markdown", - "id": "2a486bc3", - "metadata": {}, - "source": [ - "Let's have a look at the data!\n", - "Each image is a tiff stack containing 7 images of the same tissue recorded with different scan time settings of a Scanning Electron Miscroscope (SEM). The faster a SEM image is scanned, the noisier it gets." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5fbcf59e", - "metadata": {}, - "outputs": [], - "source": [ - "imgs = imread(\"data/SEM/train/train.tif\")\n", - "x_size = imgs.shape\n", - "print(\"image size =\", x_size)\n", - "scantimes_all = [\"0.2us\", \"0.5us\", \"1us\", \"1us\", \"2.1us\", \"5us\", \"5us, avg of 4\"]\n", - "plt.figure(figsize=(40, 16))\n", - "plot_some(imgs, title_list=[scantimes_all], pmin=0.2, pmax=99.8, cmap=\"gray_r\")" - ] - }, - { - "cell_type": "markdown", - "id": "e13f36f4", - "metadata": {}, - "source": [ - "---\n", - "

\n", - " TASK 2.1:

\n", - "

\n", - " The noise level is hard to see at this zoom level. Let's also look at a smaller crop of them! Play around with this until you have a feeling for what the data looks like.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "86ddce74", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "###TODO###\n", - "\n", - "imgs_cropped = ... # TODO" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "59a780db", - "metadata": {}, - "outputs": [], - "source": [ - "plt.figure(figsize=(40, 16))\n", - "plot_some(imgs_cropped, title_list=[scantimes_all], pmin=0.2, pmax=99.8, cmap=\"gray_r\")" - ] - }, - { - "cell_type": "markdown", - "id": "d0223757", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "249ff869", - "metadata": {}, - "outputs": [], - "source": [ - "# checking that you didn't crop x_train itself, we still need that!\n", - "assert imgs.shape == x_size" - ] - }, - { - "cell_type": "markdown", - "id": "97253add", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "As you can see the last image, which is the average of 4 images with 5$\\mu s$ scantime, has the highest signal-to-noise-ratio. It is not noise-free but our best choice to be able to compare our results against quantitatively, so we will set it aside for that purpose." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "82165898", - "metadata": {}, - "outputs": [], - "source": [ - "scantimes, scantime_highSNR = scantimes_all[:-1], scantimes_all[-1]\n", - "x_train, x_highSNR = imgs[:-1], imgs[-1]\n", - "print(scantimes, scantime_highSNR)\n", - "print(x_train.shape, x_highSNR.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "c904033d", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Generate training data for CARE\n", - "\n", - "Let's try and train a network to denoise images of $1 \\mu s$ scan time!\n", - "Which images do you think could be used as input and which as target?\n", - "\n", - "---\n", - "

\n", - " TASK 2.2:

\n", - "

\n", - " Decide which images to use as inputs and which as targets. Then, remember from part one how the data has to be organized to match up inputs and targets.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5dce0b0b", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "###TODO###\n", - "base_path = \"data/SEM/train\"\n", - "source_dir = os.path.join(base_path, \"\") # pick path in which to save inputs\n", - "target_dir = os.path.join(base_path, \"\") # pick path in which to save targets" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6d9b0181", - "metadata": {}, - "outputs": [], - "source": [ - "os.makedirs(source_dir, exist_ok=True)\n", - "os.makedirs(target_dir, exist_ok=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "92fff631", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "# Now save individual images into these directories\n", - "# You can use the imwrite function to save images. The ? command will pull up the docstring\n", - "?imwrite" - ] - }, - { - "cell_type": "markdown", - "id": "f426a521", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Hint: The tiff file you read earlier contained 7 images for the different instances. Here, use a single tiff file per image." - ] - }, - { - "cell_type": "markdown", - "id": "ac8d428c", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Hint: Remember we're trying to train a Noise2Noise network here, so the target does not need to be clean." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a8701a98", - "metadata": {}, - "outputs": [], - "source": [ - "###TODO###\n", - "\n", - "# Put the pairs of input and target images into the `source_dir` and `target_dir`, respectively.\n", - "# The goal here is to the train a network for 1 us scan time." - ] - }, - { - "cell_type": "markdown", - "id": "dfc0f4ae", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "---\n", - "---\n", - "

\n", - " TASK 2.3:

\n", - "

\n", - " Now that you arranged the training data we can now create the raw data object.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f048fbce", - "metadata": {}, - "outputs": [], - "source": [ - "###TODO###\n", - "raw_data = RawData.from_folder(\n", - " basepath=\"data/SEM/train\",\n", - " source_dirs=[\"\"], # fill in your directory for source images\n", - " target_dir=\"\", # fill in your directory of target images\n", - " axes=\"\", # what should the axes tag be?\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "86a23463", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---\n", - "We generate 2D patches. If you'd like, you can play around with the parameters here." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ef0ee336", - "metadata": {}, - "outputs": [], - "source": [ - "X, Y, XY_axes = create_patches(\n", - " raw_data=raw_data,\n", - " patch_size=(256, 256),\n", - " n_patches_per_image=512,\n", - " save_file=\"data/SEM/my_1us_training_data.npz\",\n", - ")\n", - "\n", - "assert X.shape == Y.shape\n", - "print(\"shape of X,Y =\", X.shape)\n", - "print(\"axes of X,Y =\", XY_axes)" - ] - }, - { - "cell_type": "markdown", - "id": "daf15a26", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Show\n", - "\n", - "Let's look at some of the generated patch pairs. (odd rows: _source_, even rows: _target_)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6227c8fe", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "for i in range(2):\n", - " plt.figure(figsize=(16, 4))\n", - " sl = slice(8 * i, 8 * (i + 1)), 0\n", - " plot_some(\n", - " X[sl], Y[sl], title_list=[np.arange(sl[0].start, sl[0].stop)], cmap=\"gray_r\"\n", - " )\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "fbaf33e4", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Part 2: Training the network\n", - "\n", - "\n", - "### Load Training data\n", - "\n", - "Load the patches generated in part 1, use 10% as validation data." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ef2231ad", - "metadata": {}, - "outputs": [], - "source": [ - "(X, Y), (X_val, Y_val), axes = load_training_data(\n", - " \"data/SEM/my_1us_training_data.npz\", validation_split=0.1, verbose=True\n", - ")\n", - "\n", - "c = axes_dict(axes)[\"C\"]\n", - "n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n", - "\n", - "\n", - "plt.figure(figsize=(12, 5))\n", - "plot_some(X_val[:5], Y_val[:5], cmap=\"gray_r\", pmin=0.2, pmax=99.8)\n", - "plt.suptitle(\"5 example validation patches (top row: source, bottom row: target)\")\n", - "\n", - "config = Config(\n", - " axes, n_channel_in, n_channel_out, train_steps_per_epoch=10, train_epochs=100\n", - ")\n", - "vars(config)" - ] - }, - { - "cell_type": "markdown", - "id": "c53ca47d", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "We now create a CARE model with the chosen configuration:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "386877f3", - "metadata": {}, - "outputs": [], - "source": [ - "model = CARE(config, \"my_N2N_model\", basedir=\"models\")" - ] - }, - { - "cell_type": "markdown", - "id": "4a170adb", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Training\n", - "\n", - "Training the model will likely take some time. We recommend to monitor the progress with [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard), which allows you to inspect the losses during training.\n", - "Furthermore, you can look at the predictions for some of the validation images, which can be helpful to recognize problems early on.\n", - "\n", - "Start tensorboard as you did in the previous notebook." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "383cc0fb", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "%tensorboard --logdir models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "afd5ce7b", - "metadata": {}, - "outputs": [], - "source": [ - "history = model.train(X, Y, validation_data=(X_val, Y_val))" - ] - }, - { - "cell_type": "markdown", - "id": "242c2a9c", - "metadata": {}, - "source": [ - "Plot final training history (available in TensorBoard during training):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7607957c", - "metadata": {}, - "outputs": [], - "source": [ - "print(sorted(list(history.history.keys())))\n", - "plt.figure(figsize=(16, 5))\n", - "plot_history(history, [\"loss\", \"val_loss\"], [\"mse\", \"val_mse\", \"mae\", \"val_mae\"])" - ] - }, - { - "cell_type": "markdown", - "id": "a8b12c16", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Evaluation\n", - "Example results for validation images." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7d920b92", - "metadata": {}, - "outputs": [], - "source": [ - "plt.figure(figsize=(12, 7))\n", - "_P = model.keras_model.predict(X_val[:5])\n", - "if config.probabilistic:\n", - " _P = _P[..., : (_P.shape[-1] // 2)]\n", - "plot_some(X_val[:5], Y_val[:5], _P, pmin=0.2, pmax=99.8, cmap=\"gray_r\")\n", - "plt.suptitle(\n", - " \"5 example validation patches\\n\"\n", - " \"top row: input (noisy source), \"\n", - " \"mid row: target (independently noisy), \"\n", - " \"bottom row: predicted from source, \"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "72321ef2", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Part 3: Prediction\n", - "\n", - "\n", - "### Load CARE model\n", - "\n", - "Load trained model (located in base directory `models` with name `my_model`) from disk.\n", - "The configuration was saved during training and is automatically loaded when `CARE` is initialized with `config=None`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dbdb29ac", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "model = CARE(config=None, name=\"my_N2N_model\", basedir=\"models\")" - ] - }, - { - "cell_type": "markdown", - "id": "ee7ffaf8", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Apply CARE network to raw image\n", - "Now use the trained model to denoise some test images. Let's load the whole tiff stack first" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c6c2f73d", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "path_test_data = \"data/SEM/test/test.tif\"\n", - "test_imgs = imread(path_test_data)\n", - "axes = \"YX\"\n", - "\n", - "# separate out the high SNR image as before\n", - "x_test, x_test_highSNR = test_imgs[:-1], test_imgs[-1]" - ] - }, - { - "cell_type": "markdown", - "id": "0112bf1b", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "---\n", - "

\n", - " TASK 2.4:

\n", - "

\n", - " Write a function that applies the model to one of the images in the tiff stack. Code to visualize the result by plotting the noisy image alongside the restored image as well as smaller crops of each is provided.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fbc45be7", - "metadata": {}, - "outputs": [], - "source": [ - "###TODO###\n", - "def apply_on_test(predict_model, img_idx, plot=True):\n", - " \"\"\"\n", - " Apply the given model on the test image at the given index of the tiff stack.\n", - " Returns the noisy image, restored image and the scantime.\n", - " \"\"\"\n", - " # TODO: insert your code for prediction here\n", - " scantime = ... # get scantime for `img_idx`th image\n", - " img = ... # get `img_idx`th image\n", - " restored = ... # apply model to `img`\n", - " if plot:\n", - " img_crop = img[500:756, 200:456]\n", - " restored_crop = restored[500:756, 200:456]\n", - " x_test_highSNR_crop = x_test_highSNR[500:756, 200:456]\n", - " plt.figure(figsize=(20, 30))\n", - " plot_some(\n", - " np.stack([img, restored, x_test_highSNR]),\n", - " np.stack([img_crop, restored_crop, x_test_highSNR_crop]),\n", - " cmap=\"gray_r\",\n", - " title_list=[[scantime, \"restored\", scantime_highSNR]],\n", - " )\n", - " return img, restored, scantime" - ] - }, - { - "cell_type": "markdown", - "id": "770d410b", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---\n", - "\n", - "Using the function you just wrote to restore one of the images with 1us scan time." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2e780e06", - "metadata": {}, - "outputs": [], - "source": [ - "noisy_img, restored_img, scantime = apply_on_test(model, 2)\n", - "\n", - "ssi_input = structural_similarity(noisy_img, x_test_highSNR, data_range=65535)\n", - "ssi_restored = structural_similarity(restored_img, x_test_highSNR, data_range=65535)\n", - "print(\n", - " f\"Structural similarity index (higher is better) wrt average of 4x5us images: \\n\"\n", - " f\"Input: {ssi_input} \\n\"\n", - " f\"Prediction: {ssi_restored}\"\n", - ")\n", - "\n", - "psnr_input = peak_signal_noise_ratio(noisy_img, x_test_highSNR, data_range=65535)\n", - "psnr_restored = peak_signal_noise_ratio(restored_img, x_test_highSNR, data_range=65535)\n", - "print(\n", - " f\"Peak signal-to-noise ratio wrt average of 4x5us images:\\n\"\n", - " f\"Input: {psnr_input} \\n\"\n", - " f\"Prediction: {psnr_restored}\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "b268fafe", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "---\n", - "

\n", - " TASK 2.5:

\n", - "

\n", - " Be creative!\n", - "\n", - "Can you improve the results by using the data differently or by tweaking the settings?\n", - "\n", - "How could you train a single network to process all scan times?\n", - "

\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "12de7fb3", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "To train a network to process all scan times use this instead as the solution to Task 2.3:\n", - "The names \"low\" and \"GT\" don't really fit here anymore, so use names \"source_all\" and \"target_all\" instead" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "87183177", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "source_dir = \"data/SEM/train/source_all\"\n", - "target_dir = \"data/SEM/train/target_all\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "675708bd", - "metadata": {}, - "outputs": [], - "source": [ - "os.makedirs(source_dir, exist_ok=True)\n", - "os.makedirs(target_dir, exist_ok=True)" - ] - }, - { - "cell_type": "markdown", - "id": "87127c2e", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Since we wanna train a network for all scan times, we will use all images as our input images.\n", - "To train Noise2Noise we can use every other image as our target - as long as the noise is different the only remianing structure is the signal, so mixing different scan times is totally fine.\n", - "Images are paired by having the same name in `source_dir` and `target_dir`. This means we'll have several copies of the same image with different names. These images aren't very big, so that's fine." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a9ff3e56", - "metadata": {}, - "outputs": [], - "source": [ - "counter = 0\n", - "for i in range(x_train.shape[0]):\n", - " for j in range(x_train.shape[0]):\n", - " if i == j:\n", - " continue\n", - " imwrite(os.path.join(source_dir, f\"{counter}.tif\"), x_train[i, ...])\n", - " imwrite(os.path.join(target_dir, f\"{counter}.tif\"), x_train[j, ...])\n", - " counter += 1" - ] - }, - { - "cell_type": "markdown", - "id": "fbf87638", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "---\n", - "
\n", - "

\n", - " Congratulations!

\n", - "

\n", - " You have reached the second checkpoint of this exercise! Please mark your progress in the course chat!\n", - "

\n", - "
" - ] - } - ], - "metadata": { - "jupytext": { - "cell_metadata_filter": "all", - "main_language": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/exercise3.ipynb b/exercise3.ipynb deleted file mode 100644 index 987499c..0000000 --- a/exercise3.ipynb +++ /dev/null @@ -1,642 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "7d405b3a", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "# Train a Noise2Void network\n", - "\n", - "Both the CARE network and Noise2Noise network you trained in part 1 and 2 require that you acquire additional data for the purpose of denoising. For CARE we used a paired acquisition with high SNR, for Noise2Noise we had paired noisy acquisitions.\n", - "We will now train a Noise2Void network from single noisy images.\n", - "\n", - "This notebook uses a single image from the SEM data from the Noise2Noise notebook, but as you'll see in Task 3.1 if you brought your own raw data you should adapt the notebook to use that instead.\n", - "\n", - "We now use the [Noise2Void library](https://github.com/juglab/n2v) instead of csbdeep/care, but don't worry - they're pretty similar.\n", - "\n", - "
\n", - "Set your python kernel to 03_image_restoration_part2\n", - "
\n", - "
\n", - "Make sure your previous notebook is shutdown to avoid running into GPU out-of-memory problems.\n", - "
\n", - "\n", - "---\n", - "\n", - "

\n", - " TASK 3.1

\n", - "

\n", - "This notebook uses a single image from the SEM data from the Noise2Noise notebook.\n", - "\n", - "If you brought your own raw data, use that instead!\n", - "The only requirement is that the noise in your data is pixel-independent and zero-mean. If you're unsure whether your data fulfills that requirement or you don't yet understand why it is necessary ask one of us to discuss!\n", - "\n", - "If you don't have suitable data of your own, feel free to find some online or ask your fellow course participants. You can however also stick with the SEM data provided here and compare the results to what you achieved with Noise2Noise in the previous part.\n", - "

\n", - "
\n", - "\n", - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f253352a", - "metadata": {}, - "outputs": [], - "source": [ - "# We import all our dependencies.\n", - "from n2v.models import N2VConfig, N2V\n", - "import numpy as np\n", - "from csbdeep.utils import plot_history\n", - "from n2v.utils.n2v_utils import manipulate_val_data\n", - "from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n", - "from matplotlib import pyplot as plt\n", - "import urllib\n", - "import os\n", - "from skimage.metrics import structural_similarity, peak_signal_noise_ratio\n", - "from tifffile import imread\n", - "import zipfile\n", - "\n", - "%load_ext tensorboard\n", - "\n", - "import ssl\n", - "\n", - "ssl._create_default_https_context = ssl._create_unverified_context" - ] - }, - { - "cell_type": "markdown", - "id": "557ec582", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Part 1: Prepare data\n", - "Let's make sure the data is there!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ade8c11d", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "assert os.path.exists(\"data/SEM/train/train.tif\")\n", - "assert os.path.exists(\"data/SEM/test/test.tif\")" - ] - }, - { - "cell_type": "markdown", - "id": "0f458875", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "We create a N2V_DataGenerator object to help load data and extract patches for training and validation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "569e0c45", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "datagen = N2V_DataGenerator()" - ] - }, - { - "cell_type": "markdown", - "id": "90ef2146", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "The data generator provides two methods for loading data: `load_imgs_from_directory` and `load_imgs`. Let's look at their docstring to figure out how to use it." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2e752db3", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "?N2V_DataGenerator.load_imgs_from_directory" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3b68d54f", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "?N2V_DataGenerator.load_imgs" - ] - }, - { - "cell_type": "markdown", - "id": "7cd57bd4", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "The SEM images are all in one directory, so we'll use `load_imgs_from_directory`. We'll pass in that directory (`\"data/SEM/train\"`), our image matches the default filter (`\"*.tif\"`) so we do not need to specify that. But our tif image is a stack of several images, so as dims we need to specify `\"TYX\"`.\n", - "If you're using your own data adapt this part to match your use case. If these functions aren't suitable for your use case load your images manually.\n", - "Feel free to ask a TA for help if you're unsure how to get your data loaded!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "289e03ce", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "imgs = datagen.load_imgs_from_directory(\"data/SEM/train\", dims=\"TYX\")\n", - "print(f\"Loaded {len(imgs)} images.\")\n", - "print(f\"First image has shape {imgs[0].shape}\")" - ] - }, - { - "cell_type": "markdown", - "id": "3bb5b8f5", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "The method returned a list of images, as per the doc string the dimensions of each are \"SYXC\". However, we only want to use one of the images here since Noise2Void is designed to work with just one acquisition of the sample. Let's use the first image at $1\\mu s$ scantime." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "02c80399", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "imgs = [img[2:3, :, :, :] for img in imgs]\n", - "print(f\"First image has shape {imgs[0].shape}\")" - ] - }, - { - "cell_type": "markdown", - "id": "18841f39", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "For generating patches the datagenerator provides the methods `generate_patches` and `generate_patches_from_list`. 
As before, let's have a quick look at the docstring" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8acfd6f1", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "?N2V_DataGenerator.generate_patches" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ad205d8f", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "?N2V_DataGenerator.generate_patches_from_list" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d1ffa91f", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "type(imgs)" - ] - }, - { - "cell_type": "markdown", - "id": "4073063c", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Our `imgs` object is a list, so `generate_patches_from_list` is the suitable function." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fd8ad59a", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "patches = datagen.generate_patches_from_list(imgs, shape=(96, 96))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bf2fcfce", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "# split into training and validation\n", - "n_train = int(round(0.9 * patches.shape[0]))\n", - "X, X_val = patches[:n_train, ...], patches[n_train:, ...]" - ] - }, - { - "cell_type": "markdown", - "id": "09ded741", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "As per usual, let's look at a training and validation patch to make sure everything looks okay." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b5e2aa5a", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "plt.figure(figsize=(14, 7))\n", - "plt.subplot(1, 2, 1)\n", - "plt.imshow(X[np.random.randint(X.shape[0]), ..., 0], cmap=\"gray_r\")\n", - "plt.title(\"Training patch\")\n", - "plt.subplot(1, 2, 2)\n", - "plt.imshow(X_val[np.random.randint(X_val.shape[0]), ..., 0], cmap=\"gray_r\")\n", - "plt.title(\"Validation patch\")" - ] - }, - { - "cell_type": "markdown", - "id": "9adf5aae", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Part 2: Configure and train the Noise2Void Network\n", - "\n", - "Noise2Void comes with a special config-object, where we store network-architecture and training specific parameters. See the docstring of the N2VConfig constructor for a description of all parameters.\n", - "\n", - "When creating the config-object, we provide the training data X. From X the library will extract mean and std that will be used to normalize all data before it is processed by the network.\n", - "\n", - "\n", - "Compared to supervised training (i.e. traditional CARE), we recommend to use N2V with an increased train_batch_size (e.g. 128) and batch_norm.\n", - "\n", - "To keep the network from learning the identity we have to manipulate the input pixels for the blindspot during training. How to exactly manipulate those values is controlled via the n2v_manipulator parameter with default value 'uniform_withCP' which samples a random value from the surrounding pixels, including the value at the control point. The size of the surrounding area can be configured via n2v_neighborhood_radius.\n", - "\n", - "The [paper supplement](https://arxiv.org/src/1811.10980v2/anc/supp_small.pdf) describes other pixel manipulators as well (section 3.1). If you want to configure one of those use the following values for n2v_manipulator:\n", - "* \"normal_additive\" for Gaussian (n2v_neighborhood_radius will set sigma)\n", - "* \"normal_fitted\" for Gaussian Fitting\n", - "* \"normal_withoutCP\" for Gaussian Pixel Selection\n", - "\n", - "For faster training multiple pixels per input patch can be manipulated. In our experiments we manipulated about 0.198% of the input pixels per patch. For a patch size of 64 by 64 pixels this corresponds to about 8 pixels. This fraction can be tuned via n2v_perc_pix.\n", - "\n", - "For Noise2Void training it is possible to pass arbitrarily large patches to the training method. From these patches random subpatches of size n2v_patch_shape are extracted during training. Default patch shape is set to (64, 64).\n", - "\n", - "In the past we experienced bleedthrough artifacts between channels if training was terminated to early. To counter bleedthrough we added the `single_net_per_channel` option, which is turned on by default. In the back a single U-Net for each channel is created and trained independently, thereby removing the possiblity of bleedthrough.
\n", - "Essentially the network gets multiplied by the number of channels, which increases the memory requirements. If your GPU gets too small, you can always split the channels manually and train a network for each channel one after another.\n", - "\n", - "---\n", - "

\n", - " TASK 3.2

\n", - "

\n", - "As suggested look at the docstring of the N2VConfig and then generate a configuration for your Noise2Void network, and choose a name to identify your model by.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9ec86b39", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "?N2VConfig" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "993e8ac2", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "###TODO###\n", - "config = N2VConfig()\n", - "vars(config)\n", - "model_name = \"\"" - ] - }, - { - "cell_type": "markdown", - "id": "61203b93", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8f41bab6", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "# initialize the model\n", - "model = N2V(config, model_name, basedir=\"models\")" - ] - }, - { - "cell_type": "markdown", - "id": "35bf20d3", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Now let's train the model and monitor the progress in tensorboard.\n", - "Adapt the command below as you did before." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "26940a2f", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "%tensorboard --logdir=models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aaeb5c02", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "history = model.train(X, X_val)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2a653ca4", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "print(sorted(list(history.history.keys())))\n", - "plt.figure(figsize=(16, 5))\n", - "plot_history(history, [\"loss\", \"val_loss\"])" - ] - }, - { - "cell_type": "markdown", - "id": "b520725d", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Part 3: Prediction\n", - "\n", - "Similar to CARE a previously trained model is loaded by creating a new N2V-object without providing a `config`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a00aee95", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "model = N2V(config=None, name=model_name, basedir=\"models\")" - ] - }, - { - "cell_type": "markdown", - "id": "e1b5d86e", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Let's load a $1\\mu s$ scantime test images and denoise them using our network and like before we'll use the high SNR image to make a quantitative comparison. If you're using your own data and don't have an equivalent you can ignore that part." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c429e183", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "test_img = imread(\"data/SEM/test/test.tif\")[2, ...]\n", - "test_img_highSNR = imread(\"data/SEM/test/test.tif\")[-1, ...]\n", - "print(f\"Loaded test image with shape {test_img.shape}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "28325ad3", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "test_denoised = model.predict(test_img, axes=\"YX\", n_tiles=(2, 1))" - ] - }, - { - "cell_type": "markdown", - "id": "84a3b87f", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Let's look at the results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9a1ee796", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "plt.figure(figsize=(30, 30))\n", - "plt.subplot(2, 3, 1)\n", - "plt.imshow(test_img, cmap=\"gray_r\")\n", - "plt.title(\"Noisy test image\")\n", - "plt.subplot(2, 3, 4)\n", - "plt.imshow(test_img[2000:2200, 500:700], cmap=\"gray_r\")\n", - "plt.subplot(2, 3, 2)\n", - "plt.imshow(test_denoised, cmap=\"gray_r\")\n", - "plt.title(\"Denoised test image\")\n", - "plt.subplot(2, 3, 5)\n", - "plt.imshow(test_denoised[2000:2200, 500:700], cmap=\"gray_r\")\n", - "plt.subplot(2, 3, 3)\n", - "plt.imshow(test_img_highSNR, cmap=\"gray_r\")\n", - "plt.title(\"High SNR image (4x5us)\")\n", - "plt.subplot(2, 3, 6)\n", - "plt.imshow(test_img_highSNR[2000:2200, 500:700], cmap=\"gray_r\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "561e5559", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---\n", - "

\n", - " TASK 3.3

\n", - "

\n", - "\n", - "If you're using the SEM data (or happen to have a high SNR version of the image you predicted from) compare the structural similarity index and peak signal to noise ratio (wrt the high SNR image) of the noisy input image and the predicted image. If not, just skip this task.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a61bcb1f", - "metadata": {}, - "outputs": [], - "source": [ - "###TODO###\n", - "ssi_input = ... # TODO\n", - "ssi_restored = ... # TODO\n", - "print(\n", - " f\"Structural similarity index (higher is better) wrt average of 4x5us images: \\n\"\n", - " f\"Input: {ssi_input} \\n\"\n", - " f\"Prediction: {ssi_restored}\"\n", - ")\n", - "psnr_input = ... # TODO\n", - "psnr_restored = ... # TODO\n", - "print(\n", - " f\"Peak signal-to-noise ratio (higher is better) wrt average of 4x5us images:\\n\"\n", - " f\"Input: {psnr_input} \\n\"\n", - " f\"Prediction: {psnr_restored}\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "8e5e97cc", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "---\n", - "
\n", - "

\n", - " Congratulations!

\n", - "

\n", - " You have reached the third checkpoint of this exercise! Please mark your progress in the course chat!\n", - "

\n", - "

\n", - " Consider sharing some pictures of your results on element, especially if you used your own data.\n", - "

\n", - "

\n", - " If there's still time, check out the bonus exercise.\n", - "

\n", - "
" - ] - } - ], - "metadata": { - "jupytext": { - "cell_metadata_filter": "all", - "main_language": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/exercise_bonus1.ipynb b/exercise_bonus1.ipynb deleted file mode 100644 index f631876..0000000 --- a/exercise_bonus1.ipynb +++ /dev/null @@ -1,1099 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "b1b7576d", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "# Train Probabilistic Noise2Void\n", - "\n", - "Probabilistic Noise2Void, just as N2V, allows training from single noisy images.\n", - "\n", - "In order to get some additional quality squeezed out of your noisy input data, PN2V employs an additional noise model which can either be measured directly at your microscope or approximated by a process called ‘bootstrapping’.\n", - "Below we will give you a noise model for the first network to train and then bootstrap one, so you can apply PN2V to your own data if you'd like.\n", - "\n", - "Note: The PN2V implementation is written in pytorch, not Keras/TF.\n", - "\n", - "Note: PN2V experienced multiple updates regarding noise model representations. Hence, the [original PN2V repository](https://github.com/juglab/pn2v) is not any more the one we suggest to use (despite it of course working just as described in the original publication). So here we use the [PPN2V repo](https://github.com/juglab/PPN2V) which you installed during setup.\n", - "\n", - "
\n", - "Set your python kernel to 03_image_restoration_bonus\n", - "
\n", - "
\n", - "Make sure your previous notebook is shutdown to avoid running into GPU out-of-memory problems.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a56c4a75", - "metadata": {}, - "outputs": [], - "source": [ - "import warnings\n", - "\n", - "warnings.filterwarnings(\"ignore\")\n", - "import torch\n", - "\n", - "dtype = torch.float\n", - "device = torch.device(\"cuda:0\")\n", - "from torch.distributions import normal\n", - "import matplotlib.pyplot as plt, numpy as np, pickle\n", - "from scipy.stats import norm\n", - "from tifffile import imread\n", - "import sys\n", - "import os\n", - "import urllib\n", - "import zipfile" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9ce2cb17", - "metadata": {}, - "outputs": [], - "source": [ - "from ppn2v.pn2v import histNoiseModel, gaussianMixtureNoiseModel\n", - "from ppn2v.pn2v.utils import plotProbabilityDistribution, PSNR\n", - "from ppn2v.unet.model import UNet\n", - "from ppn2v.pn2v import training, prediction" - ] - }, - { - "cell_type": "markdown", - "id": "e8f8283c", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "## Data Preperation\n", - "\n", - "Here we use a dataset of 2D images of fluorescently labeled membranes of Convallaria (lilly of the valley) acquired with a spinning disk microscope.\n", - "All 100 recorded images (1024×1024 pixels) show the same region of interest and only differ in their noise." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f62d2875", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "# Check that data download was successful\n", - "assert os.path.exists(\"data/Convallaria_diaphragm\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7c73978a", - "metadata": {}, - "outputs": [], - "source": [ - "path = \"data/Convallaria_diaphragm/\"\n", - "data_name = \"convallaria\" # Name of the noise model\n", - "calibration_fn = \"20190726_tl_50um_500msec_wf_130EM_FD.tif\"\n", - "noisy_fn = \"20190520_tl_25um_50msec_05pc_488_130EM_Conv.tif\"\n", - "noisy_imgs = imread(path + noisy_fn)\n", - "calibration_imgs = imread(path + calibration_fn)" - ] - }, - { - "cell_type": "markdown", - "id": "773f73ca", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "This notebook has a total of four options to generate a noise model for PN2V. You can pick which one you would like to use (and ignore the tasks in the options you don't wanna use)!\n", - "\n", - "There are two types of noise models for PN2V: creating a histogram of the noisy pixels based on the averaged GT or using a gaussian mixture model (GMM).\n", - "For both we need to provide a clean signal as groundtruth. For the dataset we have here we have calibration data available so you can choose between using the calibration data or bootstrapping the model by training a N2V network." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "78c9cfb5", - "metadata": {}, - "outputs": [], - "source": [ - "n_gaussian = 3 # Number of gaussians to use for Gaussian Mixture Model\n", - "n_coeff = 2 # No. of polynomial coefficients for parameterizing the mean, standard deviation and weight of Gaussian components." - ] - }, - { - "cell_type": "markdown", - "id": "dbfe7373", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Choice 1: Generate a Noise Model using Calibration Data\n", - "The noise model is a characteristic of your camera. The downloaded data folder contains a set of calibration images (For the Convallaria dataset, it is ```20190726_tl_50um_500msec_wf_130EM_FD.tif``` and the data to be denoised is named ```20190520_tl_25um_50msec_05pc_488_130EM_Conv.tif```). We can either bin the noisy - GT pairs (obtained from noisy calibration images) as a 2-D histogram or fit a GMM distribution to obtain a smooth, parametric description of the noise model.\n", - "\n", - "We will use pairs of noisy calibration observations $x_i$ and clean signal $s_i$ (created by averaging these noisy, calibration images) to estimate the conditional distribution $p(x_i|s_i)$. Histogram-based and Gaussian Mixture Model-based noise models are generated and saved." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4f08cf73", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "name_hist_noise_model_cal = \"_\".join([\"HistNoiseModel\", data_name, \"calibration\"])\n", - "name_gmm_noise_model_cal = \"_\".join(\n", - " [\"GMMNoiseModel\", data_name, str(n_gaussian), str(n_coeff), \"calibration\"]\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "b1b1ae65", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---\n", - "

\n", - " TASK 4.1

\n", - "

\n", - "\n", - "The calibration data contains 100 images of a static sample. Estimate the clean signal by averaging all the images.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d828180c", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "###TODO###\n", - "# Average the images in `calibration_imgs`\n", - "signal_cal = ... # TODO" - ] - }, - { - "cell_type": "markdown", - "id": "96746d74", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Let's visualize a single image from the observation array alongside the average to see how the raw data compares to the pseudo ground truth signal." - ] - }, - { - "cell_type": "markdown", - "id": "2b50122c", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d71a7778", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "plt.figure(figsize=(12, 12))\n", - "plt.subplot(1, 2, 2)\n", - "plt.title(label=\"average (ground truth)\")\n", - "plt.imshow(signal_cal[0], cmap=\"gray\")\n", - "plt.subplot(1, 2, 1)\n", - "plt.title(label=\"single raw image\")\n", - "plt.imshow(calibration_imgs[0], cmap=\"gray\")\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1456576b", - "metadata": {}, - "outputs": [], - "source": [ - "# The subsequent code expects the signal array to have a dimension for the samples\n", - "if signal_cal.shape == calibration_imgs.shape[1:]:\n", - " signal_cal = signal_cal[np.newaxis, ...]" - ] - }, - { - "cell_type": "markdown", - "id": "c690fc8f", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "There are two ways of generating a noise model for PN2V: creating a histogram of the noisy pixels based on the averaged GT or using a gaussian mixture model (GMM). You can pick which one you wanna use!\n", - "\n", - "
\n", - "\n", - "### Choice 1A: Creating the Histogram Noise Model\n", - "Using the raw pixels $x_i$, and our averaged GT $s_i$, we are now learning a histogram based noise model. It describes the distribution $p(x_i|s_i)$ for each $s_i$.\n", - "\n", - "---\n", - "

\n", - " TASK 4.2

\n", - "

\n", - " Look at the docstring for createHistogram and use it to create a histogram based on the calibration data using the clean signal you created by averaging as groundtruth.

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fbd00eb2", - "metadata": {}, - "outputs": [], - "source": [ - "?histNoiseModel.createHistogram" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb78b79e", - "metadata": {}, - "outputs": [], - "source": [ - "###TODO###\n", - "# Define the parameters for the histogram creation\n", - "bins = 256\n", - "# Values falling outside the range [min_val, max_val] are not included in the histogram, so the values in the images you want to denoise should fall within that range\n", - "min_val = ... # TODO\n", - "max_val = ... # TODO\n", - "# Create the histogram\n", - "histogram_cal = histNoiseModel.createHistogram(bins, ...) # TODO" - ] - }, - { - "cell_type": "markdown", - "id": "5ea0dffb", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fc393b96", - "metadata": {}, - "outputs": [], - "source": [ - "# Saving histogram to disk.\n", - "np.save(path + name_hist_noise_model_cal + \".npy\", histogram_cal)\n", - "histogramFD_cal = histogram_cal[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4f920fcf", - "metadata": {}, - "outputs": [], - "source": [ - "# Let's look at the histogram-based noise model.\n", - "plt.xlabel(\"Observation Bin\")\n", - "plt.ylabel(\"Signal Bin\")\n", - "plt.imshow(histogramFD_cal**0.25, cmap=\"gray\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "5993f09c", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "### Choice 1B: Creating the GMM noise model\n", - "Using the raw pixels $x_i$, and our averaged GT $s_i$, we are now learning a GMM based noise model. It describes the distribution $p(x_i|s_i)$ for each $s_i$." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "655c66f9", - "metadata": {}, - "outputs": [], - "source": [ - "min_signal = np.min(signal_cal)\n", - "max_signal = np.max(signal_cal)\n", - "print(\"Minimum Signal Intensity is\", min_signal)\n", - "print(\"Maximum Signal Intensity is\", max_signal)" - ] - }, - { - "cell_type": "markdown", - "id": "35722d03", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Iterating the noise model training for `n_epoch=2000` and `batchSize=250000` works the best for `Convallaria` dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b056b9e6", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "?gaussianMixtureNoiseModel.GaussianMixtureNoiseModel" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9ffb712e", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "gmm_noise_model_cal = gaussianMixtureNoiseModel.GaussianMixtureNoiseModel(\n", - " min_signal=min_signal,\n", - " max_signal=max_signal,\n", - " path=path,\n", - " weight=None,\n", - " n_gaussian=n_gaussian,\n", - " n_coeff=n_coeff,\n", - " min_sigma=50,\n", - " device=device,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aa8892fd", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "gmm_noise_model_cal.train(\n", - " signal_cal,\n", - " calibration_imgs,\n", - " batchSize=250000,\n", - " n_epochs=2000,\n", - " learning_rate=0.1,\n", - " name=name_gmm_noise_model_cal,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "7305eeb0", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "### Visualizing the Histogram-based and GMM-based noise models\n", - "\n", - "This only works if you generated both a histogram (Choice 1A) and GMM-based (Choice 1B) noise model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d060c437", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "plotProbabilityDistribution(\n", - " signalBinIndex=170,\n", - " histogram=histogramFD_cal,\n", - " gaussianMixtureNoiseModel=gmm_noise_model_cal,\n", - " min_signal=min_val,\n", - " max_signal=max_val,\n", - " n_bin=bins,\n", - " device=device,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "e63e2061", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Choice 2: Generate a Noise Model by Bootstrapping\n", - "\n", - "Here we bootstrap a suitable histogram noise model and a GMM noise model after denoising the noisy images with Noise2Void and then using these denoised images as pseudo GT.\n", - "So first, we need to train a N2V model (now with pytorch) to estimate the conditional distribution $p(x_i|s_i)$. No additional calibration data is used for bootstrapping (so no need to use `calibration_imgs` or `singal_cal` again)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8a4145cb", - "metadata": {}, - "outputs": [], - "source": [ - "model_name = data_name + \"_n2v\"\n", - "name_hist_noise_model_bootstrap = \"_\".join([\"HistNoiseModel\", data_name, \"bootstrap\"])\n", - "name_gmm_noise_model_bootstrap = \"_\".join(\n", - " [\"GMMNoiseModel\", data_name, str(n_gaussian), str(n_coeff), \"bootstrap\"]\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f076055e", - "metadata": {}, - "outputs": [], - "source": [ - "# Configure the Noise2Void network\n", - "n2v_net = UNet(1, depth=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0d02b99c", - "metadata": {}, - "outputs": [], - "source": [ - "# Prepare training+validation data\n", - "train_data = noisy_imgs[:-5].copy()\n", - "val_data = noisy_imgs[-5:].copy()\n", - "np.random.shuffle(train_data)\n", - "np.random.shuffle(val_data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2dfc50a3", - "metadata": {}, - "outputs": [], - "source": [ - "train_history, val_history = training.trainNetwork(\n", - " net=n2v_net,\n", - " trainData=train_data,\n", - " valData=val_data,\n", - " postfix=model_name,\n", - " directory=path,\n", - " noiseModel=None,\n", - " device=device,\n", - " numOfEpochs=200,\n", - " stepsPerEpoch=10,\n", - " virtualBatchSize=20,\n", - " batchSize=1,\n", - " learningRate=1e-3,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3e7261ec", - "metadata": {}, - "outputs": [], - "source": [ - "# Let's look at the training and validation loss\n", - "plt.xlabel(\"epoch\")\n", - "plt.ylabel(\"loss\")\n", - "plt.plot(val_history, label=\"validation loss\")\n", - "plt.plot(train_history, label=\"training loss\")\n", - "plt.legend()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "eb119445", - "metadata": {}, - "outputs": [], - "source": [ - "# We now run the N2V model to create pseudo groundtruth.\n", - "n2v_result_imgs = []\n", - "n2v_input_imgs = []\n", - "\n", - "for index in range(noisy_imgs.shape[0]):\n", - " im = noisy_imgs[index]\n", - " # We are using tiling to fit the image into memory\n", - " # If you get an error try a smaller patch size (ps)\n", - " n2v_pred = prediction.tiledPredict(\n", - " im, n2v_net, ps=256, overlap=48, device=device, noiseModel=None\n", - " )\n", - " n2v_result_imgs.append(n2v_pred)\n", - " n2v_input_imgs.append(im)\n", - " if index % 10 == 0:\n", - " print(\"image:\", index)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fff6264f", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "# In bootstrap mode, we estimate pseudo GT by using N2V denoised images.\n", - "signal_bootstrap = np.array(n2v_result_imgs)\n", - "# Let's look the raw data and our pseudo ground truth signal\n", - "print(signal_bootstrap.shape)\n", - "plt.figure(figsize=(12, 12))\n", - "plt.subplot(2, 2, 2)\n", - "plt.title(label=\"pseudo GT 
(generated by N2V denoising)\")\n", - "plt.imshow(signal_bootstrap[0], cmap=\"gray\")\n", - "plt.subplot(2, 2, 4)\n", - "plt.imshow(signal_bootstrap[0, -128:, -128:], cmap=\"gray\")\n", - "plt.subplot(2, 2, 1)\n", - "plt.title(label=\"single raw image\")\n", - "plt.imshow(noisy_imgs[0], cmap=\"gray\")\n", - "plt.subplot(2, 2, 3)\n", - "plt.imshow(noisy_imgs[0, -128:, -128:], cmap=\"gray\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "fd230f12", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Now that we have pseudoGT, you can pick again between a histogram based noise model and a GMM noise model\n", - "\n", - "
\n", - "\n", - "### Choice 2A: Creating the Histogram Noise Model\n", - "\n", - "---\n", - "

\n", - " TASK 4.3

\n", - "

\n", - " If you've already done Task 4.2, this is very similar!\n", - " Look at the docstring for createHistogram and use it to create a histogram using the bootstraped signal you created from the N2V predictions.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "88a4cbe7", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "?histNoiseModel.createHistogram" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "09b7ca76", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "###TODO###\n", - "# Define the parameters for the histogram creation\n", - "bins = 256\n", - "# Values falling outside the range [min_val, max_val] are not included in the histogram, so the values in the images you want to denoise should fall within that range\n", - "min_val = ... # TODO\n", - "max_val = ... # TODO\n", - "# Create the histogram\n", - "histogram_bootstrap = histNoiseModel.createHistogram(bins, ...) # TODO" - ] - }, - { - "cell_type": "markdown", - "id": "69aff158", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ad8e6df1", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "# Saving histogram to disk.\n", - "np.save(path + name_hist_noise_model_bootstrap + \".npy\", histogram_bootstrap)\n", - "histogramFD_bootstrap = histogram_bootstrap[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f5ade612", - "metadata": {}, - "outputs": [], - "source": [ - "# Let's look at the histogram-based noise model\n", - "plt.xlabel(\"Observation Bin\")\n", - "plt.ylabel(\"Signal Bin\")\n", - "plt.imshow(histogramFD_bootstrap**0.25, cmap=\"gray\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "f6074610", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "### Choice 2B: Creating the GMM noise model\n", - "Using the raw pixels $x_i$, and our averaged GT $s_i$, we are now learning a GMM based noise model. It describes the distribution $p(x_i|s_i)$ for each $s_i$." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "57f33040", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "min_signal = np.percentile(signal_bootstrap, 0.5)\n", - "max_signal = np.percentile(signal_bootstrap, 99.5)\n", - "print(\"Minimum Signal Intensity is\", min_signal)\n", - "print(\"Maximum Signal Intensity is\", max_signal)" - ] - }, - { - "cell_type": "markdown", - "id": "d775b9a4", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Iterating the noise model training for `n_epoch=2000` and `batchSize=250000` works the best for `Convallaria` dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "43a50b02", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "gmm_noise_model_bootstrap = gaussianMixtureNoiseModel.GaussianMixtureNoiseModel(\n", - " min_signal=min_signal,\n", - " max_signal=max_signal,\n", - " path=path,\n", - " weight=None,\n", - " n_gaussian=n_gaussian,\n", - " n_coeff=n_coeff,\n", - " device=device,\n", - " min_sigma=50,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4611b54b", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "gmm_noise_model_bootstrap.train(\n", - " signal_bootstrap,\n", - " noisy_imgs,\n", - " batchSize=250000,\n", - " n_epochs=2000,\n", - " learning_rate=0.1,\n", - " name=name_gmm_noise_model_bootstrap,\n", - " lowerClip=0.5,\n", - " upperClip=99.5,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "aaa3f882", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Visualizing the Histogram-based and GMM-based noise models\n", - "\n", - "This only works if you generated both a histogram (Choice 2A) and GMM-based (Choice 2B) noise model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "993c6b8e", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "plotProbabilityDistribution(\n", - " signalBinIndex=170,\n", - " histogram=histogramFD_bootstrap,\n", - " gaussianMixtureNoiseModel=gmm_noise_model_bootstrap,\n", - " min_signal=min_val,\n", - " max_signal=max_val,\n", - " n_bin=bins,\n", - " device=device,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "89f86336", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## PN2V Training\n", - "\n", - "---\n", - "

\n", - " TASK 4.4

\n", - "

\n", - " Adapt to use the noise model of your choice here to then train PN2V with.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0dffc131", - "metadata": {}, - "outputs": [], - "source": [ - "###TODO###\n", - "noise_model_type = \"gmm\" # pick: \"hist\" or \"gmm\"\n", - "noise_model_data = \"bootstrap\" # pick: \"calibration\" or \"bootstrap\"" - ] - }, - { - "cell_type": "markdown", - "id": "6bc7c3e9", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4fa867d1", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "# Create a network with 800 output channels that are interpreted as samples from the prior.\n", - "pn2v_net = UNet(800, depth=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "43d6e350", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "# Start training.\n", - "trainHist, valHist = training.trainNetwork(\n", - " net=pn2v_net,\n", - " trainData=train_data,\n", - " valData=val_data,\n", - " postfix=noise_model_name,\n", - " directory=path,\n", - " noiseModel=noise_model,\n", - " device=device,\n", - " numOfEpochs=200,\n", - " stepsPerEpoch=5,\n", - " virtualBatchSize=20,\n", - " batchSize=1,\n", - " learningRate=1e-3,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "57b92b13", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## PN2V Evaluation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8ae0bb6d", - "metadata": {}, - "outputs": [], - "source": [ - "test_data = noisy_imgs[\n", - " :, :512, :512\n", - "] # We are loading only a sub image to speed up computation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d074aee5", - "metadata": {}, - "outputs": [], - "source": [ - "# We estimate the ground truth by averaging.\n", - "test_data_gt = np.mean(test_data[:, ...], axis=0)[np.newaxis, ...]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6225e3d3", - "metadata": {}, - "outputs": [], - "source": [ - "pn2v_net = torch.load(path + \"/last_\" + noise_model_name + \".net\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb12628c", - "metadata": {}, - "outputs": [], - "source": [ - "# Now we are processing data and calculating PSNR values.\n", - "mmse_psnrs = []\n", - "prior_psnrs = []\n", - "input_psnrs = []\n", - "result_ims = []\n", - "input_ims = []\n", - "\n", - "# We iterate over all test images.\n", - "for index in range(test_data.shape[0]):\n", - " im = test_data[index]\n", - " gt = test_data_gt[0] # The ground truth is the same for all images\n", - "\n", - " # We are using tiling to fit the image into memory\n", - " # If you get an error try a smaller patch size (ps)\n", - " means, mse_est = prediction.tiledPredict(\n", - " im, pn2v_net, ps=192, overlap=48, device=device, noiseModel=noise_model\n", - " )\n", - "\n", - " result_ims.append(mse_est)\n", - " input_ims.append(im)\n", - "\n", - " range_psnr = np.max(gt) - np.min(gt)\n", - " psnr = PSNR(gt, mse_est, range_psnr)\n", - " psnr_prior = PSNR(gt, means, range_psnr)\n", - " input_psnr = PSNR(gt, im, range_psnr)\n", - " mmse_psnrs.append(psnr)\n", - " prior_psnrs.append(psnr_prior)\n", - " input_psnrs.append(input_psnr)\n", - "\n", - " print(\"image:\", index)\n", - " print(\"PSNR input\", input_psnr)\n", - " print(\"PSNR prior\", psnr_prior) # Without info from masked pixel\n", - " print(\"PSNR mse\", psnr) # MMSE estimate using the masked pixel\n", - " print(\"-----------------------------------\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "69438c2a", - "metadata": {}, - "outputs": [], - "source": [ - "?prediction.tiledPredict" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d9c27130", - "metadata": {}, - "outputs": [], - "source": [ - "# We display the results for the last test image\n", - "vmi = np.percentile(gt, 0.01)\n", - "vma = np.percentile(gt, 99)\n", - "\n", - "plt.figure(figsize=(15, 15))\n", - "plt.subplot(1, 3, 1)\n", - "plt.title(label=\"Input Image\")\n", - "plt.imshow(im, vmax=vma, vmin=vmi, cmap=\"magma\")\n", - "\n", - "plt.subplot(1, 3, 2)\n", - "plt.title(label=\"Avg. Prior\")\n", - "plt.imshow(means, vmax=vma, vmin=vmi, cmap=\"magma\")\n", - "\n", - "plt.subplot(1, 3, 3)\n", - "plt.title(label=\"PN2V-MMSE estimate\")\n", - "plt.imshow(mse_est, vmax=vma, vmin=vmi, cmap=\"magma\")\n", - "plt.show()\n", - "\n", - "plt.figure(figsize=(15, 15))\n", - "plt.subplot(1, 3, 1)\n", - "plt.title(label=\"Input Image\")\n", - "plt.imshow(im[100:200, 150:250], vmax=vma, vmin=vmi, cmap=\"magma\")\n", - "plt.axhline(y=50, linewidth=3, color=\"white\", alpha=0.5, ls=\"--\")\n", - "\n", - "plt.subplot(1, 3, 2)\n", - "plt.title(label=\"Avg. 
Prior\")\n", - "plt.imshow(means[100:200, 150:250], vmax=vma, vmin=vmi, cmap=\"magma\")\n", - "plt.axhline(y=50, linewidth=3, color=\"white\", alpha=0.5, ls=\"--\")\n", - "\n", - "plt.subplot(1, 3, 3)\n", - "plt.title(label=\"PN2V-MMSE estimate\")\n", - "plt.imshow(mse_est[100:200, 150:250], vmax=vma, vmin=vmi, cmap=\"magma\")\n", - "plt.axhline(y=50, linewidth=3, color=\"white\", alpha=0.5, ls=\"--\")\n", - "\n", - "\n", - "plt.figure(figsize=(15, 5))\n", - "plt.plot(im[150, 150:250], label=\"Input Image\")\n", - "plt.plot(means[150, 150:250], label=\"Avg. Prior\")\n", - "plt.plot(mse_est[150, 150:250], label=\"PN2V-MMSE estimate\")\n", - "plt.plot(gt[150, 150:250], label=\"Pseudo GT by averaging\")\n", - "plt.legend()\n", - "\n", - "plt.show()\n", - "print(\n", - " \"Avg PSNR Prior:\",\n", - " np.mean(np.array(prior_psnrs)),\n", - " \"+-(2SEM)\",\n", - " 2 * np.std(np.array(prior_psnrs)) / np.sqrt(float(len(prior_psnrs))),\n", - ")\n", - "print(\n", - " \"Avg PSNR MMSE:\",\n", - " np.mean(np.array(mmse_psnrs)),\n", - " \"+-(2SEM)\",\n", - " 2 * np.std(np.array(mmse_psnrs)) / np.sqrt(float(len(mmse_psnrs))),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "66930ec5", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "---\n", - "---\n", - "

\n", - " TASK 4.5

\n", - "

\n", - " Try PN2V for your own data! You probably don't have calibration data, but with the bootstrapping method you don't need any!\n", - "

\n", - "
\n", - "\n", - "---\n", - "\n", - "
\n", - "

\n", - " Congratulations!

\n", - "

\n", - " You have completed the bonus exercise!\n", - "

\n", - "
" - ] - } - ], - "metadata": { - "jupytext": { - "cell_metadata_filter": "all", - "main_language": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/exercise_bonus2.md b/exercise_bonus2.md deleted file mode 100644 index f87826b..0000000 --- a/exercise_bonus2.md +++ /dev/null @@ -1,9 +0,0 @@ -# Second Bonus Exercise: DivNoising - -[DivNoising](https://openreview.net/pdf?id=agHLCOBM5jP) is one of the latest unsupervised denoising methods and follows a somewhat different approach. Instead of a U-Net, DivNoising employs the power of a Variational Auto-Encoder (VAE), but adds the use of a noise model to it in a suitable way. - -A nice perk of this approach is that you will be able to sample diverse interpretations of a noisy image. Why is that useful? If the diverse samples look all the same or very similar to each other you can infer that the data is not very ambigious and you might decide to trust the result more. If on the other hand, the samples look quite different you know that you might not want to trust any of the denoised "interpretations". In the DivNoising paper you can also see how the diverse samples can be used in meaningful ways for imporved downstream analysis. - -Since you've made it this far, you're clearly a pro so we will now take off the training wheels, essentially putting you in the position you would find yourself in when you come across a method you find interesting and want to check it out. - -Clone the Div Noising repository from here: https://github.com/juglab/DivNoising, follow the setup instructions there and run through the Convallaria example. diff --git a/nb_material/notebook_shutdown.png b/nb_material/notebook_shutdown.png deleted file mode 100644 index ecaf8c5..0000000 Binary files a/nb_material/notebook_shutdown.png and /dev/null differ diff --git a/pyscripts/convert-solution.py b/pyscripts/convert-solution.py deleted file mode 100644 index 279f787..0000000 --- a/pyscripts/convert-solution.py +++ /dev/null @@ -1,41 +0,0 @@ -import argparse -from traitlets.config import Config -import nbformat as nbf -from nbconvert.preprocessors import TagRemovePreprocessor, ClearOutputPreprocessor -from nbconvert.exporters import NotebookExporter - - -def get_arg_parser(): - parser = argparse.ArgumentParser() - - parser.add_argument('input_file') - parser.add_argument('output_file') - - return parser - - -def convert(input_file, output_file): - c = Config() - c.TagRemovePreprocessor.remove_cell_tags = ("solution",) - c.TagRemovePreprocessor.enabled = True - c.ClearOutputPreprocesser.enabled = True - c.NotebookExporter.preprocessors = [ - "nbconvert.preprocessors.TagRemovePreprocessor", - "nbconvert.preprocessors.ClearOutputPreprocessor" - ] - - exporter = NotebookExporter(config=c) - exporter.register_preprocessor(TagRemovePreprocessor(config=c), True) - exporter.register_preprocessor(ClearOutputPreprocessor(), True) - - output = NotebookExporter(config=c).from_filename(input_file) - with open(output_file, 'w') as f: - f.write(output[0]) - - -if __name__ == "__main__": - parser = get_arg_parser() - args = parser.parse_args() - - convert(args.input_file, args.output_file) - print(f'Converted {args.input_file} to {args.output_file}') diff --git a/pyscripts/exercise1.py b/pyscripts/exercise1.py deleted file mode 100644 index 33097bc..0000000 --- a/pyscripts/exercise1.py +++ /dev/null @@ -1,353 +0,0 @@ -# %% [markdown] -""" -
- -# Train your first CARE model (supervised) - -In this first example we will train a CARE model for a 2D denoising and upsampling task, where corresponding pairs of low and high signal-to-noise ratio (SNR) images of cells are available. Here the high SNR images are acquisitions of Human U2OS cells taken from the [Broad Bioimage Benchmark Collection](https://data.broadinstitute.org/bbbc/BBBC006/) and the low SNR images were created by synthetically adding *strong read-out and shot-noise* and applying *pixel binning* of 2x2, thus mimicking acquisitions at a very low light level. - -![](nb_material/denoising_binning_overview.png) - - -For CARE, image pairs should be registered, which in practice is best achieved by acquiring both stacks _interleaved_, i.e. as different channels that correspond to the different exposure/laser settings. - -Since the image pairs were synthetically created in this example, they are already aligned perfectly. Note that when working with real paired acquisitions, the low and high SNR images are not pixel-perfect aligned so typically need to be co-registered before training a CARE model. - -To train a denoising network, we will use the [CSBDeep Repo](https://github.com/CSBDeep/CSBDeep). This notebook has a very similar structure to the examples you can find there. -More documentation is available at http://csbdeep.bioimagecomputing.com/doc/. - -This part will not have any coding tasks, but go through each cell and try to understand what's going on - it will help you in the next part! We also put some questions along the way. For some of them you might need to dig a bit deeper. - -
-Set your python kernel to 03_image_restoration_part1 -
-""" - -# %% -from __future__ import absolute_import, division, print_function, unicode_literals - -import matplotlib.pyplot as plt -import os -import numpy as np -from csbdeep.data import ( - RawData, - create_patches, - no_background_patches, - norm_percentiles, - sample_percentiles, -) -from csbdeep.io import load_training_data, save_tiff_imagej_compatible -from csbdeep.models import CARE, Config -from csbdeep.utils import ( - Path, - axes_dict, - normalize, - plot_history, - plot_some, -) -from csbdeep.utils.tf import limit_gpu_memory - -# %matplotlib inline -# %load_ext tensorboard -# %config InlineBackend.figure_format = 'retina' -from tifffile import imread - -# %% [markdown] -""" -
- -## Part 1: Training Data Generation -Network training usually happens on batches of smaller sized images than the ones recorded on a microscopy. In this first part of the exercise, we will load all of the image data and chop it into smaller pieces, a.k.a. patches. - -### Look at example data - -During setup, we downloaded some example data, consisting of low-SNR and high-SNR 3D images of Tribolium. -Note that `GT` stands for ground truth and represents high signal-to-noise ratio (SNR) stacks. -""" -# %% -assert os.path.exists("data/U2OS") -# %% [markdown] -""" -As we can see, the data set is already split into a **train** and **test** set, each containing (synthetically generated) low SNR ("low") and corresponding high SNR ("GT") images. - -Let's look at an example pair of training images: -""" -# %% -y = imread("data/U2OS/train/GT/img_0010.tif") -x = imread("data/U2OS/train/low/img_0010.tif") -print("GT image size =", x.shape) -print("low-SNR image size =", y.shape) - -# %% -plt.figure(figsize=(13, 5)) -plt.subplot(1, 2, 1) -plt.imshow(x, cmap="magma") -plt.colorbar() -plt.title("low") -plt.subplot(1, 2, 2) -plt.imshow(y, cmap="magma") -plt.colorbar() -plt.title("high") -plt.show() -# %% [markdown] -""" -### Generate training data for CARE - -We first need to create a `RawData` object, which defines how to get the pairs of low/high SNR stacks and the semantics of each axis (e.g. which one is considered a color channel, etc.). In general the names for the axes are: - -X: columns, Y: rows, Z: planes, C: channels, T: frames/time, (S: samples/images) - -Here we have two folders "low" and "GT", where corresponding low and high-SNR stacks are TIFF images with identical filenames. - -For this case, we can simply use `RawData.from_folder` and set `axes = 'YX'` to indicate the semantic order of the image axes, i.e. we have two-dimensional images in standard xy layout. -""" -# %% -raw_data = RawData.from_folder( - basepath="data/U2OS/train", - source_dirs=["low"], - target_dir="GT", - axes="YX", -) -# %% [markdown] -""" -From corresponding images, we now generate some 2D patches to use for training. - -As a general rule, use a *patch size* that is a power of two along all axes, or at least divisible by 8. Typically, you should use more patches the more trainings images you have. - -An important aspect is *data normalization*, i.e. the rescaling of corresponding patches to a dynamic range of ~ (0,1). By default, this is automatically provided via percentile normalization, which can be adapted if needed. - -By default, patches are sampled from *non-background regions* (i.e. that are above a relative threshold). We will disable this for the current example as most image regions already contain foreground pixels and thus set the threshold to 0. See the documentation of `create_patches` for details. - -Note that returned values `(X, Y, XY_axes)` by `create_patches` are not to be confused with the image axes X and Y. By convention, the variable name X (or x) refers to an input variable for a machine learning model, whereas Y (or y) indicates an output variable. -""" -# %% -X, Y, XY_axes = create_patches( - raw_data=raw_data, - patch_size=(128, 128), - patch_filter=no_background_patches(0), - n_patches_per_image=2, - save_file="data/U2OS/my_training_data.npz", -) - -assert X.shape == Y.shape -print("shape of X,Y =", X.shape) -print("axes of X,Y =", XY_axes) - -# %% [markdown] -""" -### Show - -This shows some of the generated patch pairs (odd rows: *source*, even rows: *target*). 
-""" -# %% -for i in range(2): - plt.figure(figsize=(16, 4)) - sl = slice(8 * i, 8 * (i + 1)), 0 - plot_some( - X[sl], Y[sl], title_list=[np.arange(sl[0].start, sl[0].stop)] - ) # convenience function provided by CSB Deep - plt.show() -# %% [markdown] -""" -

- Questions:

-
    -
  1. Where is the training data located?
  2. How is the data organized to identify the pairs of HR and LR images?
-
- -
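An optional aside before moving on (not part of the original exercise): the percentile normalization mentioned above can be adapted by passing a `normalization` argument to `create_patches`. A minimal sketch, assuming the `norm_percentiles`/`sample_percentiles` helpers from csbdeep behave as documented; the percentile ranges below are only example values:

# %%
from csbdeep.data import norm_percentiles, sample_percentiles

# same call as above, but with randomly sampled normalization percentiles per patch
X_alt, Y_alt, _ = create_patches(
    raw_data=raw_data,
    patch_size=(128, 128),
    patch_filter=no_background_patches(0),
    n_patches_per_image=2,
    # lower percentile drawn from (0.5, 3), upper percentile drawn from (99.5, 99.9)
    normalization=norm_percentiles(sample_percentiles((0.5, 3), (99.5, 99.9))),
)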
- -## Part 2: Training the network - - -### Load Training data - -Load the patches generated in part 1, use 10% as validation data. -""" -# %% -(X, Y), (X_val, Y_val), axes = load_training_data( - "data/U2OS/my_training_data.npz", validation_split=0.1, verbose=True -) - -c = axes_dict(axes)["C"] -n_channel_in, n_channel_out = X.shape[c], Y.shape[c] - -# %% -plt.figure(figsize=(12, 5)) -plot_some(X_val[:5], Y_val[:5]) -plt.suptitle("5 example validation patches (top row: source, bottom row: target)") -plt.show() - -# %% [markdown] -""" -### Configure the CARE model -Before we construct the actual CARE model, we have to define its configuration via a `Config` object, which includes -* parameters of the underlying neural network, -* the learning rate, -* the number of parameter updates per epoch, -* the loss function, and -* whether the model is probabilistic or not. - -![](nb_material/carenet.png) - -The defaults should be sensible in many cases, so a change should only be necessary if the training process fails. - -Important: Note that for this notebook we use a very small number of update steps for immediate feedback, whereas the number of epochs and steps per epoch should be increased considerably (e.g. `train_steps_per_epoch=400`, `train_epochs=100`) to obtain a well-trained model. -""" -# %% -config = Config( - axes, - n_channel_in, - n_channel_out, - train_batch_size=8, - train_steps_per_epoch=40, - train_epochs=20, -) -vars(config) -# %% [markdown] -""" -We now create a CARE model with the chosen configuration: -""" -# %% -model = CARE(config, "my_CARE_model", basedir="models") -# %% [markdown] -""" -We can get a summary of all the layers in the model and the number of parameters: -""" -# %% -model.keras_model.summary() -# %% [markdown] -""" -### Training - -Training the model will likely take some time. We recommend to monitor the progress with [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard), which allows you to inspect the losses during training. -Furthermore, you can look at the predictions for some of the validation images, which can be helpful to recognize problems early on. - -We can start tensorboard within the notebook. - -Alternatively, you can launch the notebook in an independent tab by changing the `%` to `!` -
-If you're using ssh, add --host <hostname> to the command: -! tensorboard --logdir models --host <hostname>, where <hostname> is your machine's hostname (the address ending in amazonaws.com). -
-""" -# %% -# %tensorboard --logdir models -# %% -history = model.train(X, Y, validation_data=(X_val, Y_val)) -# %% [markdown] -""" -Plot final training history (available in TensorBoard during training): -""" -# %% -print(sorted(list(history.history.keys()))) -plt.figure(figsize=(16, 5)) -plot_history(history, ["loss", "val_loss"], ["mse", "val_mse", "mae", "val_mae"]) - -# %% [markdown] -""" -### Evaluation -Example results for validation images. -""" -# %% -plt.figure(figsize=(12, 7)) -_P = model.keras_model.predict(X_val[:5]) -if config.probabilistic: - _P = _P[..., : (_P.shape[-1] // 2)] -plot_some(X_val[:5], Y_val[:5], _P, pmax=99.5) -plt.suptitle( - "5 example validation patches\n" - "top row: input (source), " - "middle row: target (ground truth), " - "bottom row: predicted from source" -) -# %% [markdown] -""" -

- Questions:

-
    -
  1. Where are trained models stored? What models are being stored, and how do they differ?
  2. How does the name of the saved models get specified?
  3. How can you influence the number of training steps per epoch? What did you use?
-
- -
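One of the questions above asks how to influence the number of training steps per epoch. As a hedged sketch (reusing the variables defined above; the epoch count is only an example), a common choice is to set `train_steps_per_epoch` so that every patch is seen roughly once per epoch given the batch size:

# %%
# with train_batch_size=8 as above, one epoch then covers all training patches approximately once
steps_per_epoch = int(np.ceil(X.shape[0] / 8))
config_full = Config(
    axes,
    n_channel_in,
    n_channel_out,
    train_batch_size=8,
    train_steps_per_epoch=steps_per_epoch,
    train_epochs=100,
)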
- -## Part 3: Prediction - -Plot the test stack pair and define its image axes, which will be needed later for CARE prediction. -""" -# %% -y_test = imread("data/U2OS/test/GT/img_0010.tif") -x_test = imread("data/U2OS/test/low/img_0010.tif") - -axes = "YX" -print("image size =", x_test.shape) -print("image axes =", axes) - -plt.figure(figsize=(16, 10)) -plot_some(np.stack([x_test, y_test]), title_list=[["low", "high"]]) - - -# %% [markdown] -""" -### Load CARE model - -Load trained model (located in base directory `models` with name `my_CARE_model`) from disk. -The configuration was saved during training and is automatically loaded when `CARE` is initialized with `config=None`. -""" -# %% -model = CARE(config=None, name="my_CARE_model", basedir="models") -# %% [markdown] -""" -### Apply CARE network to raw image -Predict the restored image (image will be successively split into smaller tiles if there are memory issues). -""" -# %% -# %%time -restored = model.predict(x_test, axes) - -# %% [markdown] -""" -### Save restored image - -Save the restored image stack as a ImageJ-compatible TIFF image, i.e. the image can be opened in ImageJ/Fiji with correct axes semantics. -""" -# %% -Path("results").mkdir(exist_ok=True) -save_tiff_imagej_compatible("results/%s_img_0010.tif" % model.name, restored, axes) - -# %% [markdown] -""" -### Visualize results -Plot the test stack pair and the predicted restored stack (middle). -""" - -# %% -plt.figure(figsize=(15, 10)) -plot_some( - np.stack([x_test, restored, y_test]), - title_list=[["low", "CARE", "GT"]], - pmin=2, - pmax=99.8, -) - -plt.figure(figsize=(10, 5)) -for _x, _name in zip((x_test, restored, y_test), ("low", "CARE", "GT")): - plt.plot(normalize(_x, 1, 99.7)[180], label=_name, lw=2) -plt.legend() -plt.show() - -# %% [markdown] -""" -
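Two optional extras before wrapping up this part (a sketch, not part of the original exercise): `model.predict` also accepts an `n_tiles` argument for images that do not fit into GPU memory, and the restoration can be compared quantitatively against the ground truth with scikit-image metrics. The (2, 2) tile layout is just an example value:

# %%
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# explicit tiling during prediction (useful if the full image causes out-of-memory errors)
restored_tiled = model.predict(x_test, axes, n_tiles=(2, 2))

# quantitative comparison of input and restoration against the high-SNR ground truth
gt = y_test.astype(np.float32)
rng = float(gt.max() - gt.min())
print("PSNR input   :", peak_signal_noise_ratio(gt, x_test.astype(np.float32), data_range=rng))
print("PSNR restored:", peak_signal_noise_ratio(gt, restored.astype(np.float32), data_range=rng))
print("SSIM restored:", structural_similarity(gt, restored.astype(np.float32), data_range=rng))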
-

- Congratulations!

-

- You have reached the first checkpoint of this exercise! Please mark your progress in the course chat! -

-
-""" diff --git a/pyscripts/prepare-exercise.sh b/pyscripts/prepare-exercise.sh deleted file mode 100644 index ce4f7bd..0000000 --- a/pyscripts/prepare-exercise.sh +++ /dev/null @@ -1,17 +0,0 @@ -# Run black on .py files -black exercise1.py solution2.py solution3.py solution_bonus1.py - - -# Convert .py to ipynb -# "cell_metadata_filter": "all" preserve cell tags including our solution tags -jupytext --to ipynb --update-metadata '{"jupytext": {"cell_metadata_filter":"all"}}' exercise1.py --output ../exercise1.ipynb -jupytext --to ipynb --update-metadata '{"jupytext": {"cell_metadata_filter":"all"}}' solution2.py --output ../solution2.ipynb -jupytext --to ipynb --update-metadata '{"jupytext": {"cell_metadata_filter":"all"}}' solution3.py --output ../solution3.ipynb -jupytext --to ipynb --update-metadata '{"jupytext": {"cell_metadata_filter":"all"}}' solution_bonus1.py --output ../solution_bonus1.ipynb - -# Create the exercise notebook by removing cell outputs and deleting cells tagged with "solution" -# There is a bug in the nbconvert cli so we need to use the python API instead -python convert-solution.py ../solution2.ipynb ../exercise2.ipynb -python convert-solution.py ../solution3.ipynb ../exercise3.ipynb -python convert-solution.py ../solution_bonus1.ipynb ../exercise_bonus1.ipynb - diff --git a/pyscripts/solution2.py b/pyscripts/solution2.py deleted file mode 100644 index e0e78b8..0000000 --- a/pyscripts/solution2.py +++ /dev/null @@ -1,475 +0,0 @@ -# %% [markdown] -""" -
- -# Train a Noise2Noise network with CARE -
-Set your python kernel to 03_image_restoration_part1! That's the same as for the first notebook. -
- -We will now train a 2D Noise2Noise network using CARE. We will closely follow along with the previous example, but now you will have to fill in some parts on your own! -You will have to make decisions - make them! - -But first, some clean-up... -
-Make sure your previous notebook is shut down to avoid running into GPU out-of-memory problems. -
- -![](nb_material/notebook_shutdown.png) -""" -# %% -from __future__ import absolute_import, division, print_function, unicode_literals - -import gc -import os - -import matplotlib.pyplot as plt -import numpy as np -from csbdeep.data import RawData, create_patches -from csbdeep.io import load_training_data, save_tiff_imagej_compatible -from csbdeep.models import CARE, Config -from csbdeep.utils import ( - Path, - axes_dict, - plot_history, - plot_some, -) -from csbdeep.utils.tf import limit_gpu_memory - -# %matplotlib inline -# %load_ext tensorboard -# %config InlineBackend.figure_format = 'retina' -from skimage.metrics import peak_signal_noise_ratio, structural_similarity -from tifffile import imread, imwrite - -# %% [markdown] -""" -
- -## Part 1: Training Data Generation - -### Download example data - -To train a Noise2Noise setup we need several acquisitions of the same sample. -The SEM data we downloaded during setup contains 2 tiff-stacks, one for training and one for testing, let's make sure it's there! -""" -# %% -assert os.path.exists("data/SEM/train/train.tif") -assert os.path.exists("data/SEM/test/test.tif") - -# %% [markdown] -# Let's have a look at the data! -# Each image is a tiff stack containing 7 images of the same tissue recorded with different scan time settings of a Scanning Electron Miscroscope (SEM). The faster a SEM image is scanned, the noisier it gets. - -# %% -imgs = imread("data/SEM/train/train.tif") -x_size = imgs.shape -print("image size =", x_size) -scantimes_all = ["0.2us", "0.5us", "1us", "1us", "2.1us", "5us", "5us, avg of 4"] -plt.figure(figsize=(40, 16)) -plot_some(imgs, title_list=[scantimes_all], pmin=0.2, pmax=99.8, cmap="gray_r") - -# %% [markdown] -# --- -#

-# TASK 2.1:

-#

-# The noise level is hard to see at this zoom level. Let's also look at a smaller crop of them! Play around with this until you have a feeling for what the data looks like. -#

-#
- -# %% -###TODO### - -imgs_cropped = ... # TODO -# %% tags=["solution"] -imgs_cropped = imgs[:, 1000:1128, 600:728] -# %% -plt.figure(figsize=(40, 16)) -plot_some(imgs_cropped, title_list=[scantimes_all], pmin=0.2, pmax=99.8, cmap="gray_r") - -# %% [markdown] -""" ---- -""" -# %% -# checking that you didn't crop x_train itself, we still need that! -assert imgs.shape == x_size - -# %% [markdown] -""" -As you can see the last image, which is the average of 4 images with 5$\mu s$ scantime, has the highest signal-to-noise-ratio. It is not noise-free but our best choice to be able to compare our results against quantitatively, so we will set it aside for that purpose. -""" -# %% -scantimes, scantime_highSNR = scantimes_all[:-1], scantimes_all[-1] -x_train, x_highSNR = imgs[:-1], imgs[-1] -print(scantimes, scantime_highSNR) -print(x_train.shape, x_highSNR.shape) - -# %% [markdown] -""" -### Generate training data for CARE - -Let's try and train a network to denoise images of $1 \mu s$ scan time! -Which images do you think could be used as input and which as target? - ---- -

- TASK 2.2:

-

- Decide which images to use as inputs and which as targets. Then, remember from part one how the data has to be organized to match up inputs and targets. -

-
-""" -# %% -###TODO### -base_path = "data/SEM/train" -source_dir = os.path.join(base_path, "") # pick path in which to save inputs -target_dir = os.path.join(base_path, "") # pick path in which to save targets -# %% tags=["solution"] -# The names "low" and "GT" don't really fit here anymore, so use names "source" and "target" instead -base_path = "data/SEM/train" -source_dir = os.path.join(base_path, "source_1us") -target_dir = os.path.join(base_path, "target_1us") - -# %% -os.makedirs(source_dir, exist_ok=True) -os.makedirs(target_dir, exist_ok=True) - -# %% -# Now save individual images into these directories -# You can use the imwrite function to save images. The ? command will pull up the docstring -# ?imwrite -# %% [markdown] -""" -Hint: The tiff file you read earlier contained 7 images for the different instances. Here, use a single tiff file per image. -""" -# %% [markdown] -""" -Hint: Remember we're trying to train a Noise2Noise network here, so the target does not need to be clean. -""" -# %% -###TODO### - -# Put the pairs of input and target images into the `source_dir` and `target_dir`, respectively. -# The goal here is to the train a network for 1 us scan time. - -# %% tags = ["solution"] -# Since we wanna train a network for images of 1us scan time, we will use the two images as our input images. -# For both of these images we can use every other image as our target - as long as the noise is different the -# only remaining structure is the signal, so mixing different scan times is totally fine. -# Images are paired by having the same name in `source_dir` and `target_dir`. This means we'll have several -# copies of the same image with different names. These images aren't very big, so that's fine. -counter = 0 -for i in range(2, 4): - for j in range(x_train.shape[0]): - if i == j: - continue - imwrite(os.path.join(source_dir, f"{counter}.tif"), x_train[i, ...]) - imwrite(os.path.join(target_dir, f"{counter}.tif"), x_train[j, ...]) - counter += 1 -# %% [markdown] -""" ---- ---- -

- TASK 2.3:

-

- Now that you have arranged the training data, we can create the raw data object.

-
-""" - -# %% -###TODO### -raw_data = RawData.from_folder( - basepath="data/SEM/train", - source_dirs=[""], # fill in your directory for source images - target_dir="", # fill in your directory of target images - axes="", # what should the axes tag be? -) - -# %% tags=["solution"] -raw_data = RawData.from_folder( - basepath="data/SEM/train", - source_dirs=["source_1us"], # fill in your directory for source images - target_dir="target_1us", # fill in your directory of target images - axes="YX", # what should the axes tag be? -) -# %% [markdown] -""" ---- -We generate 2D patches. If you'd like, you can play around with the parameters here. -""" -# %% -X, Y, XY_axes = create_patches( - raw_data=raw_data, - patch_size=(256, 256), - n_patches_per_image=512, - save_file="data/SEM/my_1us_training_data.npz", -) - -assert X.shape == Y.shape -print("shape of X,Y =", X.shape) -print("axes of X,Y =", XY_axes) - -# %% [markdown] -""" -### Show - -Let's look at some of the generated patch pairs. (odd rows: _source_, even rows: _target_) -""" -# %% -for i in range(2): - plt.figure(figsize=(16, 4)) - sl = slice(8 * i, 8 * (i + 1)), 0 - plot_some( - X[sl], Y[sl], title_list=[np.arange(sl[0].start, sl[0].stop)], cmap="gray_r" - ) -plt.show() - - -# %% [markdown] -""" -
- -## Part 2: Training the network - - -### Load Training data - -Load the patches generated in part 1, use 10% as validation data. -""" -# %% -(X, Y), (X_val, Y_val), axes = load_training_data( - "data/SEM/my_1us_training_data.npz", validation_split=0.1, verbose=True -) - -c = axes_dict(axes)["C"] -n_channel_in, n_channel_out = X.shape[c], Y.shape[c] - - -plt.figure(figsize=(12, 5)) -plot_some(X_val[:5], Y_val[:5], cmap="gray_r", pmin=0.2, pmax=99.8) -plt.suptitle("5 example validation patches (top row: source, bottom row: target)") - -config = Config( - axes, n_channel_in, n_channel_out, train_steps_per_epoch=10, train_epochs=100 -) -vars(config) - -# %% [markdown] -""" -We now create a CARE model with the chosen configuration: -""" -# %% -model = CARE(config, "my_N2N_model", basedir="models") - -# %% [markdown] -""" -### Training - -Training the model will likely take some time. We recommend to monitor the progress with [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard), which allows you to inspect the losses during training. -Furthermore, you can look at the predictions for some of the validation images, which can be helpful to recognize problems early on. - -Start tensorboard as you did in the previous notebook. -""" -# %% -# %tensorboard --logdir models -# %% -history = model.train(X, Y, validation_data=(X_val, Y_val)) - -# %% [markdown] -# Plot final training history (available in TensorBoard during training): - -# %% -print(sorted(list(history.history.keys()))) -plt.figure(figsize=(16, 5)) -plot_history(history, ["loss", "val_loss"], ["mse", "val_mse", "mae", "val_mae"]) - -# %% [markdown] -""" -### Evaluation -Example results for validation images. -""" -# %% -plt.figure(figsize=(12, 7)) -_P = model.keras_model.predict(X_val[:5]) -if config.probabilistic: - _P = _P[..., : (_P.shape[-1] // 2)] -plot_some(X_val[:5], Y_val[:5], _P, pmin=0.2, pmax=99.8, cmap="gray_r") -plt.suptitle( - "5 example validation patches\n" - "top row: input (noisy source), " - "mid row: target (independently noisy), " - "bottom row: predicted from source, " -) - -# %% [markdown] -""" -
- -## Part 3: Prediction - - -### Load CARE model - -Load trained model (located in base directory `models` with name `my_model`) from disk. -The configuration was saved during training and is automatically loaded when `CARE` is initialized with `config=None`. -""" -# %% -model = CARE(config=None, name="my_N2N_model", basedir="models") -# %% [markdown] -""" -### Apply CARE network to raw image -Now use the trained model to denoise some test images. Let's load the whole tiff stack first -""" -# %% -path_test_data = "data/SEM/test/test.tif" -test_imgs = imread(path_test_data) -axes = "YX" - -# separate out the high SNR image as before -x_test, x_test_highSNR = test_imgs[:-1], test_imgs[-1] - - -# %% [markdown] -""" ---- -

- TASK 2.4:

-

- Write a function that applies the model to one of the images in the tiff stack. Code is provided to visualize the result by plotting the noisy image alongside the restored image, as well as smaller crops of each.

-
-""" - - -# %% -###TODO### -def apply_on_test(predict_model, img_idx, plot=True): - """ - Apply the given model on the test image at the given index of the tiff stack. - Returns the noisy image, restored image and the scantime. - """ - # TODO: insert your code for prediction here - scantime = ... # get scantime for `img_idx`th image - img = ... # get `img_idx`th image - restored = ... # apply model to `img` - if plot: - img_crop = img[500:756, 200:456] - restored_crop = restored[500:756, 200:456] - x_test_highSNR_crop = x_test_highSNR[500:756, 200:456] - plt.figure(figsize=(20, 30)) - plot_some( - np.stack([img, restored, x_test_highSNR]), - np.stack([img_crop, restored_crop, x_test_highSNR_crop]), - cmap="gray_r", - title_list=[[scantime, "restored", scantime_highSNR]], - ) - return img, restored, scantime - - -# %% tags = ["solution"] -def apply_on_test(predict_model, img_idx, plot=True): - """ - Apply the given model on the test image at the given index of the tiff stack. - Returns the noisy image, restored image and the scantime. - """ - scantime = scantimes[img_idx] - img = x_test[img_idx, ...] - axes = "YX" - restored = predict_model.predict(img, axes) - if plot: - img_crop = img[500:756, 200:456] - restored_crop = restored[500:756, 200:456] - x_test_highSNR_crop = x_test_highSNR[500:756, 200:456] - plt.figure(figsize=(20, 30)) - plot_some( - np.stack([img, restored, x_test_highSNR]), - np.stack([img_crop, restored_crop, x_test_highSNR_crop]), - cmap="gray_r", - title_list=[[scantime, "restored", scantime_highSNR]], - ) - return img, restored, scantime - - -# %% [markdown] -""" ---- - -Using the function you just wrote to restore one of the images with 1us scan time. -""" -# %% -noisy_img, restored_img, scantime = apply_on_test(model, 2) - -ssi_input = structural_similarity(noisy_img, x_test_highSNR, data_range=65535) -ssi_restored = structural_similarity(restored_img, x_test_highSNR, data_range=65535) -print( - f"Structural similarity index (higher is better) wrt average of 4x5us images: \n" - f"Input: {ssi_input} \n" - f"Prediction: {ssi_restored}" -) - -psnr_input = peak_signal_noise_ratio(noisy_img, x_test_highSNR, data_range=65535) -psnr_restored = peak_signal_noise_ratio(restored_img, x_test_highSNR, data_range=65535) -print( - f"Peak signal-to-noise ratio wrt average of 4x5us images:\n" - f"Input: {psnr_input} \n" - f"Prediction: {psnr_restored}" -) - -# %% [markdown] -""" ---- -

- TASK 2.5:

-

- Be creative! - -Can you improve the results by using the data differently or by tweaking the settings? - -How could you train a single network to process all scan times? -

-
-""" - -# %% [markdown] -""" -To train a network to process all scan times use this instead as the solution to Task 2.3: -The names "low" and "GT" don't really fit here anymore, so use names "source_all" and "target_all" instead -""" -# %% -source_dir = "data/SEM/train/source_all" -target_dir = "data/SEM/train/target_all" -# %% -os.makedirs(source_dir, exist_ok=True) -os.makedirs(target_dir, exist_ok=True) - -# %% [markdown] -""" -Since we wanna train a network for all scan times, we will use all images as our input images. -To train Noise2Noise we can use every other image as our target - as long as the noise is different the only remianing structure is the signal, so mixing different scan times is totally fine. -Images are paired by having the same name in `source_dir` and `target_dir`. This means we'll have several copies of the same image with different names. These images aren't very big, so that's fine. -""" -# %% -counter = 0 -for i in range(x_train.shape[0]): - for j in range(x_train.shape[0]): - if i == j: - continue - imwrite(os.path.join(source_dir, f"{counter}.tif"), x_train[i, ...]) - imwrite(os.path.join(target_dir, f"{counter}.tif"), x_train[j, ...]) - counter += 1 - -# %% [markdown] -""" ---- -
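To actually train on the all-scan-times data prepared above, the remaining pipeline stays the same as for the 1 us case. A minimal sketch (the patch settings and the save path are only examples):

# %%
raw_data_all = RawData.from_folder(
    basepath="data/SEM/train",
    source_dirs=["source_all"],
    target_dir="target_all",
    axes="YX",
)
X_all, Y_all, _ = create_patches(
    raw_data=raw_data_all,
    patch_size=(256, 256),
    n_patches_per_image=512,
    save_file="data/SEM/my_all_scantimes_training_data.npz",
)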
-

- Congratulations!

-

- You have reached the second checkpoint of this exercise! Please mark your progress in the course chat! -

-
-""" diff --git a/pyscripts/solution3.py b/pyscripts/solution3.py deleted file mode 100644 index b04edd9..0000000 --- a/pyscripts/solution3.py +++ /dev/null @@ -1,316 +0,0 @@ -# %% [markdown] -""" -
- -# Train a Noise2Void network - -Both the CARE network and the Noise2Noise network you trained in parts 1 and 2 require that you acquire additional data for the purpose of denoising. For CARE we used a paired acquisition with high SNR, for Noise2Noise we had paired noisy acquisitions. -We will now train a Noise2Void network from single noisy images. - -This notebook uses a single image from the SEM data from the Noise2Noise notebook, but as you'll see in Task 3.1, if you brought your own raw data you should adapt the notebook to use that instead. - -We now use the [Noise2Void library](https://github.com/juglab/n2v) instead of csbdeep/care, but don't worry - they're pretty similar. - -
-Set your python kernel to 03_image_restoration_part2 -
-
-Make sure your previous notebook is shut down to avoid running into GPU out-of-memory problems. -
- ---- - -

- TASK 3.1

-

-This notebook uses a single image from the SEM data from the Noise2Noise notebook. - -If you brought your own raw data, use that instead! -The only requirement is that the noise in your data is pixel-independent and zero-mean. If you're unsure whether your data fulfills that requirement, or you don't yet understand why it is necessary, ask one of us to discuss! - -If you don't have suitable data of your own, feel free to find some online or ask your fellow course participants. You can, however, also stick with the SEM data provided here and compare the results to what you achieved with Noise2Noise in the previous part.

-
- ---- -""" -# %% -# We import all our dependencies. -from n2v.models import N2VConfig, N2V -import numpy as np -from csbdeep.utils import plot_history -from n2v.utils.n2v_utils import manipulate_val_data -from n2v.internals.N2V_DataGenerator import N2V_DataGenerator -from matplotlib import pyplot as plt -import urllib -import os -from skimage.metrics import structural_similarity, peak_signal_noise_ratio -from tifffile import imread -import zipfile - -# %load_ext tensorboard - -import ssl - -ssl._create_default_https_context = ssl._create_unverified_context - -# %% [markdown] -""" -
- -## Part 1: Prepare data -Let's make sure the data is there! -""" -# %% -assert os.path.exists("data/SEM/train/train.tif") -assert os.path.exists("data/SEM/test/test.tif") -# %% [markdown] -""" -We create a N2V_DataGenerator object to help load data and extract patches for training and validation. -""" -# %% -datagen = N2V_DataGenerator() -# %% [markdown] -""" -The data generator provides two methods for loading data: `load_imgs_from_directory` and `load_imgs`. Let's look at their docstring to figure out how to use it. -""" -# %% -# ?N2V_DataGenerator.load_imgs_from_directory -# %% -# ?N2V_DataGenerator.load_imgs -# %% [markdown] -""" -The SEM images are all in one directory, so we'll use `load_imgs_from_directory`. We'll pass in that directory (`"data/SEM/train"`), our image matches the default filter (`"*.tif"`) so we do not need to specify that. But our tif image is a stack of several images, so as dims we need to specify `"TYX"`. -If you're using your own data adapt this part to match your use case. If these functions aren't suitable for your use case load your images manually. -Feel free to ask a TA for help if you're unsure how to get your data loaded! -""" -# %% -imgs = datagen.load_imgs_from_directory("data/SEM/train", dims="TYX") -print(f"Loaded {len(imgs)} images.") -print(f"First image has shape {imgs[0].shape}") -# %% [markdown] -""" -The method returned a list of images, as per the doc string the dimensions of each are "SYXC". However, we only want to use one of the images here since Noise2Void is designed to work with just one acquisition of the sample. Let's use the first image at $1\mu s$ scantime. -""" -# %% -imgs = [img[2:3, :, :, :] for img in imgs] -print(f"First image has shape {imgs[0].shape}") -# %% [markdown] -""" -For generating patches the datagenerator provides the methods `generate_patches` and `generate_patches_from_list`. As before, let's have a quick look at the docstring -""" -# %% -# ?N2V_DataGenerator.generate_patches -# %% -# ?N2V_DataGenerator.generate_patches_from_list -# %% -type(imgs) -# %% [markdown] -""" -Our `imgs` object is a list, so `generate_patches_from_list` is the suitable function. -""" -# %% -patches = datagen.generate_patches_from_list(imgs, shape=(96, 96)) -# %% -# split into training and validation -n_train = int(round(0.9 * patches.shape[0])) -X, X_val = patches[:n_train, ...], patches[n_train:, ...] -# %% [markdown] -""" -As per usual, let's look at a training and validation patch to make sure everything looks okay. -""" -# %% -plt.figure(figsize=(14, 7)) -plt.subplot(1, 2, 1) -plt.imshow(X[np.random.randint(X.shape[0]), ..., 0], cmap="gray_r") -plt.title("Training patch") -plt.subplot(1, 2, 2) -plt.imshow(X_val[np.random.randint(X_val.shape[0]), ..., 0], cmap="gray_r") -plt.title("Validation patch") -# %% [markdown] -""" -
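As an aside for anyone working with their own data: if the `N2V_DataGenerator` helpers above don't fit your files, loading manually just means building a list of float32 arrays in the "SYXC" layout the generator would otherwise produce. A minimal sketch with hypothetical filenames (assuming single-channel 2D tifs):

# %%
my_files = ["my_img_1.tif", "my_img_2.tif"]  # placeholder filenames, not part of the course data
my_imgs = [
    imread(fn).astype(np.float32)[np.newaxis, ..., np.newaxis]  # add S and C axes -> (1, Y, X, 1)
    for fn in my_files
]
# my_patches = datagen.generate_patches_from_list(my_imgs, shape=(96, 96))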
- -## Part 2: Configure and train the Noise2Void Network - -Noise2Void comes with a special config-object, where we store network-architecture and training specific parameters. See the docstring of the N2VConfig constructor for a description of all parameters. - -When creating the config-object, we provide the training data X. From X the library will extract mean and std that will be used to normalize all data before it is processed by the network. - - -Compared to supervised training (i.e. traditional CARE), we recommend to use N2V with an increased train_batch_size (e.g. 128) and batch_norm. - -To keep the network from learning the identity we have to manipulate the input pixels for the blindspot during training. How to exactly manipulate those values is controlled via the n2v_manipulator parameter with default value 'uniform_withCP' which samples a random value from the surrounding pixels, including the value at the control point. The size of the surrounding area can be configured via n2v_neighborhood_radius. - -The [paper supplement](https://arxiv.org/src/1811.10980v2/anc/supp_small.pdf) describes other pixel manipulators as well (section 3.1). If you want to configure one of those use the following values for n2v_manipulator: -* "normal_additive" for Gaussian (n2v_neighborhood_radius will set sigma) -* "normal_fitted" for Gaussian Fitting -* "normal_withoutCP" for Gaussian Pixel Selection - -For faster training multiple pixels per input patch can be manipulated. In our experiments we manipulated about 0.198% of the input pixels per patch. For a patch size of 64 by 64 pixels this corresponds to about 8 pixels. This fraction can be tuned via n2v_perc_pix. - -For Noise2Void training it is possible to pass arbitrarily large patches to the training method. From these patches random subpatches of size n2v_patch_shape are extracted during training. Default patch shape is set to (64, 64). - -In the past we experienced bleedthrough artifacts between channels if training was terminated to early. To counter bleedthrough we added the `single_net_per_channel` option, which is turned on by default. In the back a single U-Net for each channel is created and trained independently, thereby removing the possiblity of bleedthrough.
-Essentially, the network gets duplicated for each channel, which increases the memory requirements accordingly. If you run out of GPU memory, you can always split the channels manually and train a network for each channel one after another. - ---
-

- TASK 3.2

-

-As suggested, look at the docstring of N2VConfig, then generate a configuration for your Noise2Void network and choose a name to identify your model by.

-
-""" -# %% -# ?N2VConfig -# %% -###TODO### -config = N2VConfig() -vars(config) -model_name = "" -# %% tags=["solution"] -# train_steps_per_epoch is set to (number of training patches)/(batch size), like this each training patch -# is shown once per epoch. -config = N2VConfig( - X, - unet_kern_size=3, - train_steps_per_epoch=int(X.shape[0] / 128), - train_epochs=200, - train_loss="mse", - batch_norm=True, - train_batch_size=128, - n2v_perc_pix=0.198, - n2v_patch_shape=(64, 64), - n2v_manipulator="uniform_withCP", - n2v_neighborhood_radius=5, -) - -# Let's look at the parameters stored in the config-object. -vars(config) -model_name = "n2v_2D" - -# %% [markdown] -""" ---- -""" -# %% -# initialize the model -model = N2V(config, model_name, basedir="models") -# %% [markdown] -""" -Now let's train the model and monitor the progress in tensorboard. -Adapt the command below as you did before. -""" -# %% -# %tensorboard --logdir=models -# %% -history = model.train(X, X_val) -# %% -print(sorted(list(history.history.keys()))) -plt.figure(figsize=(16, 5)) -plot_history(history, ["loss", "val_loss"]) -# %% [markdown] -""" -
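If you are curious about the alternative pixel manipulators listed above, only the corresponding entries of the config need to change. A hedged sketch (all other values are kept as in the solution config; training it is entirely optional):

# %%
config_gauss = N2VConfig(
    X,
    unet_kern_size=3,
    train_steps_per_epoch=int(X.shape[0] / 128),
    train_epochs=200,
    train_loss="mse",
    batch_norm=True,
    train_batch_size=128,
    n2v_perc_pix=0.198,
    n2v_patch_shape=(64, 64),
    n2v_manipulator="normal_additive",  # Gaussian manipulator instead of "uniform_withCP"
    n2v_neighborhood_radius=5,  # acts as sigma for "normal_additive"
)
# model_gauss = N2V(config_gauss, "n2v_2D_gaussian", basedir="models")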
- -## Part 3: Prediction - -Similar to CARE a previously trained model is loaded by creating a new N2V-object without providing a `config`. -""" -# %% -model = N2V(config=None, name=model_name, basedir="models") -# %% [markdown] -""" -Let's load a $1\mu s$ scantime test images and denoise them using our network and like before we'll use the high SNR image to make a quantitative comparison. If you're using your own data and don't have an equivalent you can ignore that part. -""" -# %% -test_img = imread("data/SEM/test/test.tif")[2, ...] -test_img_highSNR = imread("data/SEM/test/test.tif")[-1, ...] -print(f"Loaded test image with shape {test_img.shape}") -# %% -test_denoised = model.predict(test_img, axes="YX", n_tiles=(2, 1)) -# %% [markdown] -""" -Let's look at the results -""" -# %% -plt.figure(figsize=(30, 30)) -plt.subplot(2, 3, 1) -plt.imshow(test_img, cmap="gray_r") -plt.title("Noisy test image") -plt.subplot(2, 3, 4) -plt.imshow(test_img[2000:2200, 500:700], cmap="gray_r") -plt.subplot(2, 3, 2) -plt.imshow(test_denoised, cmap="gray_r") -plt.title("Denoised test image") -plt.subplot(2, 3, 5) -plt.imshow(test_denoised[2000:2200, 500:700], cmap="gray_r") -plt.subplot(2, 3, 3) -plt.imshow(test_img_highSNR, cmap="gray_r") -plt.title("High SNR image (4x5us)") -plt.subplot(2, 3, 6) -plt.imshow(test_img_highSNR[2000:2200, 500:700], cmap="gray_r") -plt.show() -# %% [markdown] -""" ---- -

- TASK 3.3

-

- -If you're using the SEM data (or happen to have a high SNR version of the image you predicted from) compare the structural similarity index and peak signal to noise ratio (wrt the high SNR image) of the noisy input image and the predicted image. If not, just skip this task. -

-
-""" -# %% -###TODO### -ssi_input = ... # TODO -ssi_restored = ... # TODO -print( - f"Structural similarity index (higher is better) wrt average of 4x5us images: \n" - f"Input: {ssi_input} \n" - f"Prediction: {ssi_restored}" -) -psnr_input = ... # TODO -psnr_restored = ... # TODO -print( - f"Peak signal-to-noise ratio (higher is better) wrt average of 4x5us images:\n" - f"Input: {psnr_input} \n" - f"Prediction: {psnr_restored}" -) - -# %% tags = ["solution"] -ssi_input = structural_similarity(test_img, test_img_highSNR, data_range=65535) -ssi_restored = structural_similarity(test_denoised, test_img_highSNR, data_range=65535) -print( - f"Structural similarity index (higher is better) wrt average of 4x5us images: \n" - f"Input: {ssi_input} \n" - f"Prediction: {ssi_restored}" -) -psnr_input = peak_signal_noise_ratio(test_img, test_img_highSNR, data_range=65535) -psnr_restored = peak_signal_noise_ratio( - test_denoised, test_img_highSNR, data_range=65535 -) -print( - f"Peak signal-to-noise ratio (higher is better) wrt average of 4x5us images:\n" - f"Input: {psnr_input} \n" - f"Prediction: {psnr_restored}" -) -# %% [markdown] -""" ---- -
-

- Congratulations!

-

- You have reached the third checkpoint of this exercise! Please mark your progress in the course chat! -

-

- Consider sharing some pictures of your results on element, especially if you used your own data. -

-

- If there's still time, check out the bonus exercise. -

-
-""" diff --git a/pyscripts/solution_bonus1.py b/pyscripts/solution_bonus1.py deleted file mode 100644 index 157f377..0000000 --- a/pyscripts/solution_bonus1.py +++ /dev/null @@ -1,632 +0,0 @@ -# %% [markdown] -""" -
- -# Train Probabilistic Noise2Void - -Probabilistic Noise2Void, just like N2V, allows training from single noisy images. - -In order to get some additional quality squeezed out of your noisy input data, PN2V employs an additional noise model which can either be measured directly at your microscope or approximated by a process called ‘bootstrapping’. -Below we will give you a noise model for the first network to train, and then bootstrap one, so you can apply PN2V to your own data if you'd like. - -Note: The PN2V implementation is written in PyTorch, not Keras/TF. - -Note: PN2V went through multiple updates regarding noise model representations. Hence, the [original PN2V repository](https://github.com/juglab/pn2v) is no longer the one we suggest using (although it still works just as described in the original publication). So here we use the [PPN2V repo](https://github.com/juglab/PPN2V), which you installed during setup. - -
-Set your python kernel to 03_image_restoration_bonus -
-
-Make sure your previous notebook is shut down to avoid running into GPU out-of-memory problems. -
- -""" -# %% -import warnings - -warnings.filterwarnings("ignore") -import torch - -dtype = torch.float -device = torch.device("cuda:0") -from torch.distributions import normal -import matplotlib.pyplot as plt, numpy as np, pickle -from scipy.stats import norm -from tifffile import imread -import sys -import os -import urllib -import zipfile - -# %% -from ppn2v.pn2v import histNoiseModel, gaussianMixtureNoiseModel -from ppn2v.pn2v.utils import plotProbabilityDistribution, PSNR -from ppn2v.unet.model import UNet -from ppn2v.pn2v import training, prediction - -# %% [markdown] -""" -## Data Preperation - -Here we use a dataset of 2D images of fluorescently labeled membranes of Convallaria (lilly of the valley) acquired with a spinning disk microscope. -All 100 recorded images (1024×1024 pixels) show the same region of interest and only differ in their noise. -""" - -# %% -# Check that data download was successful -assert os.path.exists("data/Convallaria_diaphragm") - - -# %% -path = "data/Convallaria_diaphragm/" -data_name = "convallaria" # Name of the noise model -calibration_fn = "20190726_tl_50um_500msec_wf_130EM_FD.tif" -noisy_fn = "20190520_tl_25um_50msec_05pc_488_130EM_Conv.tif" -noisy_imgs = imread(path + noisy_fn) -calibration_imgs = imread(path + calibration_fn) - -# %% [markdown] -""" -This notebook has a total of four options to generate a noise model for PN2V. You can pick which one you would like to use (and ignore the tasks in the options you don't wanna use)! - -There are two types of noise models for PN2V: creating a histogram of the noisy pixels based on the averaged GT or using a gaussian mixture model (GMM). -For both we need to provide a clean signal as groundtruth. For the dataset we have here we have calibration data available so you can choose between using the calibration data or bootstrapping the model by training a N2V network. -""" -# %% -n_gaussian = 3 # Number of gaussians to use for Gaussian Mixture Model -n_coeff = 2 # No. of polynomial coefficients for parameterizing the mean, standard deviation and weight of Gaussian components. - -# %% [markdown] -""" -
- -## Choice 1: Generate a Noise Model using Calibration Data -The noise model is a characteristic of your camera. The downloaded data folder contains a set of calibration images (For the Convallaria dataset, it is ```20190726_tl_50um_500msec_wf_130EM_FD.tif``` and the data to be denoised is named ```20190520_tl_25um_50msec_05pc_488_130EM_Conv.tif```). We can either bin the noisy - GT pairs (obtained from noisy calibration images) as a 2-D histogram or fit a GMM distribution to obtain a smooth, parametric description of the noise model. - -We will use pairs of noisy calibration observations $x_i$ and clean signal $s_i$ (created by averaging these noisy, calibration images) to estimate the conditional distribution $p(x_i|s_i)$. Histogram-based and Gaussian Mixture Model-based noise models are generated and saved. -""" -# %% -name_hist_noise_model_cal = "_".join(["HistNoiseModel", data_name, "calibration"]) -name_gmm_noise_model_cal = "_".join( - ["GMMNoiseModel", data_name, str(n_gaussian), str(n_coeff), "calibration"] -) -# %% [markdown] -""" ---- -

- TASK 4.1

-

- -The calibration data contains 100 images of a static sample. Estimate the clean signal by averaging all the images. -

-
-""" -# %% -###TODO### -# Average the images in `calibration_imgs` -signal_cal = ... # TODO - - -# %% tags = ["solution"] -# Average the images in `calibration_imgs` -signal_cal = np.mean(calibration_imgs[:, ...], axis=0)[np.newaxis, ...] -# %% [markdown] -""" -Let's visualize a single image from the observation array alongside the average to see how the raw data compares to the pseudo ground truth signal. -""" -# %% [markdown] -""" ---- -""" -# %% -plt.figure(figsize=(12, 12)) -plt.subplot(1, 2, 2) -plt.title(label="average (ground truth)") -plt.imshow(signal_cal[0], cmap="gray") -plt.subplot(1, 2, 1) -plt.title(label="single raw image") -plt.imshow(calibration_imgs[0], cmap="gray") -plt.show() - - -# %% -# The subsequent code expects the signal array to have a dimension for the samples -if signal_cal.shape == calibration_imgs.shape[1:]: - signal_cal = signal_cal[np.newaxis, ...] - -# %% [markdown] -""" -There are two ways of generating a noise model for PN2V: creating a histogram of the noisy pixels based on the averaged GT or using a gaussian mixture model (GMM). You can pick which one you wanna use! - -
- -### Choice 1A: Creating the Histogram Noise Model -Using the raw pixels $x_i$, and our averaged GT $s_i$, we are now learning a histogram based noise model. It describes the distribution $p(x_i|s_i)$ for each $s_i$. - ---- -

- TASK 4.2

-

- Look at the docstring for createHistogram and use it to create a histogram from the calibration data, using the clean signal you created by averaging as the ground truth.

-
-""" -# %% -# ?histNoiseModel.createHistogram - -# %% -###TODO### -# Define the parameters for the histogram creation -bins = 256 -# Values falling outside the range [min_val, max_val] are not included in the histogram, so the values in the images you want to denoise should fall within that range -min_val = ... # TODO -max_val = ... # TODO -# Create the histogram -histogram_cal = histNoiseModel.createHistogram(bins, ...) # TODO - -# %% tags = ["solution"] -# Define the parameters for the histogram creation -bins = 256 -# Values falling outside the range [min_val, max_val] are not included in the histogram, so the values in the images you want to denoise should fall within that range -min_val = 234 # np.min(noisy_imgs) -max_val = 7402 # np.max(noisy_imgs) -print("min:", min_val, ", max:", max_val) -# Create the histogram -histogram_cal = histNoiseModel.createHistogram( - bins, min_val, max_val, calibration_imgs, signal_cal -) -# %% [markdown] -""" ---- -""" -# %% -# Saving histogram to disk. -np.save(path + name_hist_noise_model_cal + ".npy", histogram_cal) -histogramFD_cal = histogram_cal[0] - -# %% -# Let's look at the histogram-based noise model. -plt.xlabel("Observation Bin") -plt.ylabel("Signal Bin") -plt.imshow(histogramFD_cal**0.25, cmap="gray") -plt.show() - -# %% [markdown] -""" -
- -### Choice 1B: Creating the GMM noise model -Using the raw pixels $x_i$, and our averaged GT $s_i$, we are now learning a GMM based noise model. It describes the distribution $p(x_i|s_i)$ for each $s_i$. -""" -# %% -min_signal = np.min(signal_cal) -max_signal = np.max(signal_cal) -print("Minimum Signal Intensity is", min_signal) -print("Maximum Signal Intensity is", max_signal) - -# %% [markdown] -""" -Iterating the noise model training for `n_epoch=2000` and `batchSize=250000` works the best for `Convallaria` dataset. -""" -# %% -# ?gaussianMixtureNoiseModel.GaussianMixtureNoiseModel -# %% -gmm_noise_model_cal = gaussianMixtureNoiseModel.GaussianMixtureNoiseModel( - min_signal=min_signal, - max_signal=max_signal, - path=path, - weight=None, - n_gaussian=n_gaussian, - n_coeff=n_coeff, - min_sigma=50, - device=device, -) -# %% -gmm_noise_model_cal.train( - signal_cal, - calibration_imgs, - batchSize=250000, - n_epochs=2000, - learning_rate=0.1, - name=name_gmm_noise_model_cal, -) -# %% [markdown] -""" -
- -### Visualizing the Histogram-based and GMM-based noise models - -This only works if you generated both a histogram (Choice 1A) and GMM-based (Choice 1B) noise model -""" -# %% -plotProbabilityDistribution( - signalBinIndex=170, - histogram=histogramFD_cal, - gaussianMixtureNoiseModel=gmm_noise_model_cal, - min_signal=min_val, - max_signal=max_val, - n_bin=bins, - device=device, -) -# %% [markdown] -""" -
- -## Choice 2: Generate a Noise Model by Bootstrapping - -Here we bootstrap a suitable histogram noise model and a GMM noise model after denoising the noisy images with Noise2Void and then using these denoised images as pseudo GT. -So first, we need to train a N2V model (now with pytorch) to estimate the conditional distribution $p(x_i|s_i)$. No additional calibration data is used for bootstrapping (so no need to use `calibration_imgs` or `singal_cal` again). -""" -# %% -model_name = data_name + "_n2v" -name_hist_noise_model_bootstrap = "_".join(["HistNoiseModel", data_name, "bootstrap"]) -name_gmm_noise_model_bootstrap = "_".join( - ["GMMNoiseModel", data_name, str(n_gaussian), str(n_coeff), "bootstrap"] -) - -# %% -# Configure the Noise2Void network -n2v_net = UNet(1, depth=3) - -# %% -# Prepare training+validation data -train_data = noisy_imgs[:-5].copy() -val_data = noisy_imgs[-5:].copy() -np.random.shuffle(train_data) -np.random.shuffle(val_data) - -# %% -train_history, val_history = training.trainNetwork( - net=n2v_net, - trainData=train_data, - valData=val_data, - postfix=model_name, - directory=path, - noiseModel=None, - device=device, - numOfEpochs=200, - stepsPerEpoch=10, - virtualBatchSize=20, - batchSize=1, - learningRate=1e-3, -) - -# %% -# Let's look at the training and validation loss -plt.xlabel("epoch") -plt.ylabel("loss") -plt.plot(val_history, label="validation loss") -plt.plot(train_history, label="training loss") -plt.legend() -plt.show() - -# %% -# We now run the N2V model to create pseudo groundtruth. -n2v_result_imgs = [] -n2v_input_imgs = [] - -for index in range(noisy_imgs.shape[0]): - im = noisy_imgs[index] - # We are using tiling to fit the image into memory - # If you get an error try a smaller patch size (ps) - n2v_pred = prediction.tiledPredict( - im, n2v_net, ps=256, overlap=48, device=device, noiseModel=None - ) - n2v_result_imgs.append(n2v_pred) - n2v_input_imgs.append(im) - if index % 10 == 0: - print("image:", index) - -# %% -# In bootstrap mode, we estimate pseudo GT by using N2V denoised images. -signal_bootstrap = np.array(n2v_result_imgs) -# Let's look the raw data and our pseudo ground truth signal -print(signal_bootstrap.shape) -plt.figure(figsize=(12, 12)) -plt.subplot(2, 2, 2) -plt.title(label="pseudo GT (generated by N2V denoising)") -plt.imshow(signal_bootstrap[0], cmap="gray") -plt.subplot(2, 2, 4) -plt.imshow(signal_bootstrap[0, -128:, -128:], cmap="gray") -plt.subplot(2, 2, 1) -plt.title(label="single raw image") -plt.imshow(noisy_imgs[0], cmap="gray") -plt.subplot(2, 2, 3) -plt.imshow(noisy_imgs[0, -128:, -128:], cmap="gray") -plt.show() -# %% [markdown] -""" -Now that we have pseudoGT, you can pick again between a histogram based noise model and a GMM noise model - -
- -### Choice 2A: Creating the Histogram Noise Model - ---- -

- TASK 4.3

-

- If you've already done Task 4.2, this is very similar! - Look at the docstring for createHistogram and use it to create a histogram using the bootstrapped signal you created from the N2V predictions. -

-
-""" -# %% -# ?histNoiseModel.createHistogram -# %% -###TODO### -# Define the parameters for the histogram creation -bins = 256 -# Values falling outside the range [min_val, max_val] are not included in the histogram, so the values in the images you want to denoise should fall within that range -min_val = ... # TODO -max_val = ... # TODO -# Create the histogram -histogram_bootstrap = histNoiseModel.createHistogram(bins, ...) # TODO -# %% tags=["solution"] -# Define the parameters for the histogram creation -bins = 256 -# Values falling outside the range [min_val, max_val] are not included in the histogram, so the values in the images you want to denoise should fall within that range -min_val = np.min(noisy_imgs) -max_val = np.max(noisy_imgs) -# Create the histogram -histogram_bootstrap = histNoiseModel.createHistogram( - bins, min_val, max_val, noisy_imgs, signal_bootstrap -) -# %% [markdown] -""" ---- -""" -# %% -# Saving histogram to disk. -np.save(path + name_hist_noise_model_bootstrap + ".npy", histogram_bootstrap) -histogramFD_bootstrap = histogram_bootstrap[0] -# %% -# Let's look at the histogram-based noise model -plt.xlabel("Observation Bin") -plt.ylabel("Signal Bin") -plt.imshow(histogramFD_bootstrap**0.25, cmap="gray") -plt.show() - -# %% [markdown] -""" -
- -### Choice 2B: Creating the GMM noise model -Using the raw pixels $x_i$, and our averaged GT $s_i$, we are now learning a GMM based noise model. It describes the distribution $p(x_i|s_i)$ for each $s_i$. -""" -# %% -min_signal = np.percentile(signal_bootstrap, 0.5) -max_signal = np.percentile(signal_bootstrap, 99.5) -print("Minimum Signal Intensity is", min_signal) -print("Maximum Signal Intensity is", max_signal) -# %% [markdown] -""" -Iterating the noise model training for `n_epoch=2000` and `batchSize=250000` works the best for `Convallaria` dataset. -""" -# %% -gmm_noise_model_bootstrap = gaussianMixtureNoiseModel.GaussianMixtureNoiseModel( - min_signal=min_signal, - max_signal=max_signal, - path=path, - weight=None, - n_gaussian=n_gaussian, - n_coeff=n_coeff, - device=device, - min_sigma=50, -) -# %% -gmm_noise_model_bootstrap.train( - signal_bootstrap, - noisy_imgs, - batchSize=250000, - n_epochs=2000, - learning_rate=0.1, - name=name_gmm_noise_model_bootstrap, - lowerClip=0.5, - upperClip=99.5, -) -# %% [markdown] -""" -### Visualizing the Histogram-based and GMM-based noise models - -This only works if you generated both a histogram (Choice 2A) and GMM-based (Choice 2B) noise model -""" -# %% -plotProbabilityDistribution( - signalBinIndex=170, - histogram=histogramFD_bootstrap, - gaussianMixtureNoiseModel=gmm_noise_model_bootstrap, - min_signal=min_val, - max_signal=max_val, - n_bin=bins, - device=device, -) -# %% [markdown] -""" -
- -## PN2V Training - ---- -

- TASK 4.4

-

- Adapt the code below to use the noise model of your choice, which will then be used to train PN2V. -

-
-""" -# %% -###TODO### -noise_model_type = "gmm" # pick: "hist" or "gmm" -noise_model_data = "bootstrap" # pick: "calibration" or "bootstrap" - -# %% tags = ["solution"] -if noise_model_type == "hist": - noise_model_name = "_".join(["HistNoiseModel", data_name, noise_model_data]) - histogram = np.load(path + noise_model_name + ".npy") - noise_model = histNoiseModel.NoiseModel(histogram, device=device) -elif noise_model_type == "gmm": - noise_model_name = "_".join( - ["GMMNoiseModel", data_name, str(n_gaussian), str(n_coeff), noise_model_data] - ) - params = np.load(path + noise_model_name + ".npz") - noise_model = gaussianMixtureNoiseModel.GaussianMixtureNoiseModel( - params=params, device=device - ) -# %% [markdown] -""" ---- -""" -# %% -# Create a network with 800 output channels that are interpreted as samples from the prior. -pn2v_net = UNet(800, depth=3) -# %% -# Start training. -trainHist, valHist = training.trainNetwork( - net=pn2v_net, - trainData=train_data, - valData=val_data, - postfix=noise_model_name, - directory=path, - noiseModel=noise_model, - device=device, - numOfEpochs=200, - stepsPerEpoch=5, - virtualBatchSize=20, - batchSize=1, - learningRate=1e-3, -) -# %% [markdown] -""" -
- -## PN2V Evaluation -""" -# %% -test_data = noisy_imgs[ - :, :512, :512 -] # We are loading only a sub image to speed up computation - -# %% -# We estimate the ground truth by averaging. -test_data_gt = np.mean(test_data[:, ...], axis=0)[np.newaxis, ...] - -# %% -pn2v_net = torch.load(path + "/last_" + noise_model_name + ".net") - -# %% -# Now we are processing data and calculating PSNR values. -mmse_psnrs = [] -prior_psnrs = [] -input_psnrs = [] -result_ims = [] -input_ims = [] - -# We iterate over all test images. -for index in range(test_data.shape[0]): - im = test_data[index] - gt = test_data_gt[0] # The ground truth is the same for all images - - # We are using tiling to fit the image into memory - # If you get an error try a smaller patch size (ps) - means, mse_est = prediction.tiledPredict( - im, pn2v_net, ps=192, overlap=48, device=device, noiseModel=noise_model - ) - - result_ims.append(mse_est) - input_ims.append(im) - - range_psnr = np.max(gt) - np.min(gt) - psnr = PSNR(gt, mse_est, range_psnr) - psnr_prior = PSNR(gt, means, range_psnr) - input_psnr = PSNR(gt, im, range_psnr) - mmse_psnrs.append(psnr) - prior_psnrs.append(psnr_prior) - input_psnrs.append(input_psnr) - - print("image:", index) - print("PSNR input", input_psnr) - print("PSNR prior", psnr_prior) # Without info from masked pixel - print("PSNR mse", psnr) # MMSE estimate using the masked pixel - print("-----------------------------------") - -# %% -# ?prediction.tiledPredict - -# %% -# We display the results for the last test image -vmi = np.percentile(gt, 0.01) -vma = np.percentile(gt, 99) - -plt.figure(figsize=(15, 15)) -plt.subplot(1, 3, 1) -plt.title(label="Input Image") -plt.imshow(im, vmax=vma, vmin=vmi, cmap="magma") - -plt.subplot(1, 3, 2) -plt.title(label="Avg. Prior") -plt.imshow(means, vmax=vma, vmin=vmi, cmap="magma") - -plt.subplot(1, 3, 3) -plt.title(label="PN2V-MMSE estimate") -plt.imshow(mse_est, vmax=vma, vmin=vmi, cmap="magma") -plt.show() - -plt.figure(figsize=(15, 15)) -plt.subplot(1, 3, 1) -plt.title(label="Input Image") -plt.imshow(im[100:200, 150:250], vmax=vma, vmin=vmi, cmap="magma") -plt.axhline(y=50, linewidth=3, color="white", alpha=0.5, ls="--") - -plt.subplot(1, 3, 2) -plt.title(label="Avg. Prior") -plt.imshow(means[100:200, 150:250], vmax=vma, vmin=vmi, cmap="magma") -plt.axhline(y=50, linewidth=3, color="white", alpha=0.5, ls="--") - -plt.subplot(1, 3, 3) -plt.title(label="PN2V-MMSE estimate") -plt.imshow(mse_est[100:200, 150:250], vmax=vma, vmin=vmi, cmap="magma") -plt.axhline(y=50, linewidth=3, color="white", alpha=0.5, ls="--") - - -plt.figure(figsize=(15, 5)) -plt.plot(im[150, 150:250], label="Input Image") -plt.plot(means[150, 150:250], label="Avg. Prior") -plt.plot(mse_est[150, 150:250], label="PN2V-MMSE estimate") -plt.plot(gt[150, 150:250], label="Pseudo GT by averaging") -plt.legend() - -plt.show() -print( - "Avg PSNR Prior:", - np.mean(np.array(prior_psnrs)), - "+-(2SEM)", - 2 * np.std(np.array(prior_psnrs)) / np.sqrt(float(len(prior_psnrs))), -) -print( - "Avg PSNR MMSE:", - np.mean(np.array(mmse_psnrs)), - "+-(2SEM)", - 2 * np.std(np.array(mmse_psnrs)) / np.sqrt(float(len(mmse_psnrs))), -) - -# %% [markdown] -""" ---- ---- -

- TASK 4.5

-

- Try PN2V for your own data! You probably don't have calibration data, but with the bootstrapping method you don't need any! -

-
- ---- - -
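For reference, here is the full bootstrapping route condensed into a single sketch. It only strings together calls already used in this notebook; the path, file name and model names are placeholders for your own data (and the output directory must exist), so treat it as a starting point rather than something that runs as-is:

# %%
my_path = "data/my_sample/"                     # hypothetical output directory
my_noisy = imread(my_path + "noisy_stack.tif")  # hypothetical stack of noisy acquisitions (>= 6 images assumed)

# 1) train N2V on the noisy stack and denoise it to obtain pseudo ground truth
my_n2v = UNet(1, depth=3)
training.trainNetwork(
    net=my_n2v, trainData=my_noisy[:-5].copy(), valData=my_noisy[-5:].copy(),
    postfix="my_sample_n2v", directory=my_path, noiseModel=None, device=device,
    numOfEpochs=200, stepsPerEpoch=10, virtualBatchSize=20, batchSize=1, learningRate=1e-3,
)
my_signal = np.array([
    prediction.tiledPredict(im, my_n2v, ps=256, overlap=48, device=device, noiseModel=None)
    for im in my_noisy
])

# 2) fit a GMM noise model to the (noisy, pseudo-GT) pairs
my_gmm_name = "GMMNoiseModel_my_sample_bootstrap"
my_gmm = gaussianMixtureNoiseModel.GaussianMixtureNoiseModel(
    min_signal=np.percentile(my_signal, 0.5), max_signal=np.percentile(my_signal, 99.5),
    path=my_path, weight=None, n_gaussian=n_gaussian, n_coeff=n_coeff, min_sigma=50, device=device,
)
my_gmm.train(
    my_signal, my_noisy, batchSize=250000, n_epochs=2000, learning_rate=0.1,
    name=my_gmm_name, lowerClip=0.5, upperClip=99.5,
)

# 3) reload the saved noise model and train PN2V with it
my_params = np.load(my_path + my_gmm_name + ".npz")
my_noise_model = gaussianMixtureNoiseModel.GaussianMixtureNoiseModel(params=my_params, device=device)
my_pn2v = UNet(800, depth=3)
training.trainNetwork(
    net=my_pn2v, trainData=my_noisy[:-5].copy(), valData=my_noisy[-5:].copy(),
    postfix="my_sample_pn2v", directory=my_path, noiseModel=my_noise_model, device=device,
    numOfEpochs=200, stepsPerEpoch=5, virtualBatchSize=20, batchSize=1, learningRate=1e-3,
)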
-

- Congratulations!

-

- You have completed the bonus exercise! -

-
-""" diff --git a/setup.sh b/setup.sh index 05646a7..acdcdce 100755 --- a/setup.sh +++ b/setup.sh @@ -1,43 +1,54 @@ -#!/bin/bash -i +#!/bin/bash -# activate base environment -mamba activate base +# create environment +ENV="05_image_restoration" +conda create -y -n "$ENV" python=3.10 +conda activate "$ENV" -# create a new environment called '03_image_restoration_part1' and initialize it with python version 3.7 -mamba create -y -n 03_image_restoration_part1 python=3.7 -# activate the environment -mamba activate 03_image_restoration_part1 -# install dependencies from conda -mamba install -y tensorflow-gpu keras jupyter tensorboard nb_conda scikit-image -# install dependencies from pip -pip install CSBDeep -# return to base environment -mamba activate base +# check that the environment was activated +if [[ "$CONDA_DEFAULT_ENV" == "$ENV" ]]; then + echo "Environment activated successfully" +else + echo "Failed to activate the environment" +fi -# create a new environment called '03_image_restoration_part2' -mamba create -y -n 03_image_restoration_part2 python=3.7 -# activate the environment -mamba activate 03_image_restoration_part2 -# install dependencies from conda -mamba install -y keras=2.3.1 tensorboard scikit-image nb_conda -# install dependencies from pip -pip install tensorflow-gpu==2.4.1 -pip install n2v -# return to base environment -mamba activate base +# Further instructions that should only run if the environment is active +if [[ "$CONDA_DEFAULT_ENV" == "$ENV" ]]; then + conda install -y pytorch-gpu cuda-toolkit=11.8 torchvision -c nvidia -c conda-forge -c pytorch + #mamba install -y pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia + pip install jupytext black nbconvert albumentations ml_collections wandb scikit-learn ipykernel gdown "careamics[examples,tensorboard] @ git+https://github.com/CAREamics/careamics.git" + # Using pytorch-lightning 2.4.0 causes bugs in tensorboard and interupting training. 
+ pip install pytorch-lightning==2.3.3 + pip install git+https://github.com/dlmbl/dlmbl-unet + python -m ipykernel install --user --name "05_image_restoration" + # Clone the extra repositories + git clone https://github.com/krulllab/COSDD.git -b n_dimensional 03_COSDD/COSDD + git clone https://github.com/juglab/denoiSplit.git 04_DenoiSplit/denoisplit -# create a new environment called '03_image_restoration_bonus' -mamba create -y -n 03_image_restoration_bonus python=3.9 -# activate the environment -mamba activate 03_image_restoration_bonus -# install pytorch depencencies -mamba install -y pytorch torchvision pytorch-cuda=11.8 'numpy<1.23' scipy matplotlib tifffile jupyter nb_conda_kernels -c pytorch -c nvidia -# install PPN2V repo from github -pip install git+https://github.com/juglab/PPN2V.git -# activate base environment -mamba activate base + # Download the data + python download_careamics_portfolio.py + cd data/ + wget "https://s3.ap-northeast-1.wasabisys.com/gigadb-datasets/live/pub/10.5524/100001_101000/100888/03-mito-confocal/mito-confocal-lowsnr.tif" + mkdir CCPs/ + cd CCPs/ + gdown 16oiMkH3cpVU500MSPbm7ghOpEMoD2YNu + cd ../ + mkdir ER/ + cd ER/ + gdown 1Bho6Oymfxi7OV0tPb9wkINkVOCpTaL7M + cd ../ + mkdir Microtubules/ + cd Microtubules/ + gdown 14sPIEE2qU2J6oRFMz46v2IvkCVjFX8D1 + cd ../ + mkdir F-actin/ + cd F-actin/ + gdown 1FYO-Bpl5vjpiJ6kzV1qO1pL37Y3Dirfy + cd ../../ + mkdir 03_COSDD/checkpoints + cd 03_COSDD/checkpoints + gdown --folder 1_oUAxagFVin71xFASb9oLF6pz20HjqTr + cd ../../ +fi -# Data download -wget https://dl-at-mbl-2023-data.s3.us-east-2.amazonaws.com/image_restoration_data.zip -unzip -q image_restoration_data.zip \ No newline at end of file diff --git a/solution2.ipynb b/solution2.ipynb deleted file mode 100644 index f030d52..0000000 --- a/solution2.ipynb +++ /dev/null @@ -1,916 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "a3ef0059", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "# Train a Noise2Noise network with CARE\n", - "
\n", - "Set your python kernel to 03_image_restoration_part1! That's the same as for the first notebook.\n", - "
\n", - "\n", - "We will now train a 2D Noise2Noise network using CARE. We will closely follow along the previous example but now you will have to fill in some parts on your own!\n", - "You will have to make decisions - make them!\n", - "\n", - "But first some clean up...\n", - "
\n", - "Make sure your previous notebook is shutdown to avoid running into GPU out-of-memory problems.\n", - "
\n", - "\n", - "![](nb_material/notebook_shutdown.png)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cd71b96a", - "metadata": {}, - "outputs": [], - "source": [ - "from __future__ import absolute_import, division, print_function, unicode_literals\n", - "\n", - "import gc\n", - "import os\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "from csbdeep.data import RawData, create_patches\n", - "from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n", - "from csbdeep.models import CARE, Config\n", - "from csbdeep.utils import (\n", - " Path,\n", - " axes_dict,\n", - " plot_history,\n", - " plot_some,\n", - ")\n", - "from csbdeep.utils.tf import limit_gpu_memory\n", - "\n", - "%matplotlib inline\n", - "%load_ext tensorboard\n", - "%config InlineBackend.figure_format = 'retina'\n", - "from skimage.metrics import peak_signal_noise_ratio, structural_similarity\n", - "from tifffile import imread, imwrite" - ] - }, - { - "cell_type": "markdown", - "id": "a04d9ec0", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Part 1: Training Data Generation\n", - "\n", - "### Download example data\n", - "\n", - "To train a Noise2Noise setup we need several acquisitions of the same sample.\n", - "The SEM data we downloaded during setup contains 2 tiff-stacks, one for training and one for testing, let's make sure it's there!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3eacfb41", - "metadata": {}, - "outputs": [], - "source": [ - "assert os.path.exists(\"data/SEM/train/train.tif\")\n", - "assert os.path.exists(\"data/SEM/test/test.tif\")" - ] - }, - { - "cell_type": "markdown", - "id": "2a486bc3", - "metadata": {}, - "source": [ - "Let's have a look at the data!\n", - "Each image is a tiff stack containing 7 images of the same tissue recorded with different scan time settings of a Scanning Electron Miscroscope (SEM). The faster a SEM image is scanned, the noisier it gets." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5fbcf59e", - "metadata": {}, - "outputs": [], - "source": [ - "imgs = imread(\"data/SEM/train/train.tif\")\n", - "x_size = imgs.shape\n", - "print(\"image size =\", x_size)\n", - "scantimes_all = [\"0.2us\", \"0.5us\", \"1us\", \"1us\", \"2.1us\", \"5us\", \"5us, avg of 4\"]\n", - "plt.figure(figsize=(40, 16))\n", - "plot_some(imgs, title_list=[scantimes_all], pmin=0.2, pmax=99.8, cmap=\"gray_r\")" - ] - }, - { - "cell_type": "markdown", - "id": "e13f36f4", - "metadata": {}, - "source": [ - "---\n", - "

\n", - " TASK 2.1:

\n", - "

\n", - " The noise level is hard to see at this zoom level. Let's also look at a smaller crop of them! Play around with this until you have a feeling for what the data looks like.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "86ddce74", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "###TODO###\n", - "\n", - "imgs_cropped = ... # TODO" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16631686", - "metadata": { - "lines_to_next_cell": 0, - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "imgs_cropped = imgs[:, 1000:1128, 600:728]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "59a780db", - "metadata": {}, - "outputs": [], - "source": [ - "plt.figure(figsize=(40, 16))\n", - "plot_some(imgs_cropped, title_list=[scantimes_all], pmin=0.2, pmax=99.8, cmap=\"gray_r\")" - ] - }, - { - "cell_type": "markdown", - "id": "d0223757", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "249ff869", - "metadata": {}, - "outputs": [], - "source": [ - "# checking that you didn't crop x_train itself, we still need that!\n", - "assert imgs.shape == x_size" - ] - }, - { - "cell_type": "markdown", - "id": "97253add", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "As you can see the last image, which is the average of 4 images with 5$\\mu s$ scantime, has the highest signal-to-noise-ratio. It is not noise-free but our best choice to be able to compare our results against quantitatively, so we will set it aside for that purpose." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "82165898", - "metadata": {}, - "outputs": [], - "source": [ - "scantimes, scantime_highSNR = scantimes_all[:-1], scantimes_all[-1]\n", - "x_train, x_highSNR = imgs[:-1], imgs[-1]\n", - "print(scantimes, scantime_highSNR)\n", - "print(x_train.shape, x_highSNR.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "c904033d", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Generate training data for CARE\n", - "\n", - "Let's try and train a network to denoise images of $1 \\mu s$ scan time!\n", - "Which images do you think could be used as input and which as target?\n", - "\n", - "---\n", - "

\n", - " TASK 2.2:

\n", - "

\n", - " Decide which images to use as inputs and which as targets. Then, remember from part one how the data has to be organized to match up inputs and targets.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5dce0b0b", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "###TODO###\n", - "base_path = \"data/SEM/train\"\n", - "source_dir = os.path.join(base_path, \"\") # pick path in which to save inputs\n", - "target_dir = os.path.join(base_path, \"\") # pick path in which to save targets" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6272dd55", - "metadata": { - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "# The names \"low\" and \"GT\" don't really fit here anymore, so use names \"source\" and \"target\" instead\n", - "base_path = \"data/SEM/train\"\n", - "source_dir = os.path.join(base_path, \"source_1us\")\n", - "target_dir = os.path.join(base_path, \"target_1us\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6d9b0181", - "metadata": {}, - "outputs": [], - "source": [ - "os.makedirs(source_dir, exist_ok=True)\n", - "os.makedirs(target_dir, exist_ok=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "92fff631", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "# Now save individual images into these directories\n", - "# You can use the imwrite function to save images. The ? command will pull up the docstring\n", - "?imwrite" - ] - }, - { - "cell_type": "markdown", - "id": "f426a521", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Hint: The tiff file you read earlier contained 7 images for the different instances. Here, use a single tiff file per image." - ] - }, - { - "cell_type": "markdown", - "id": "ac8d428c", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Hint: Remember we're trying to train a Noise2Noise network here, so the target does not need to be clean." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a8701a98", - "metadata": {}, - "outputs": [], - "source": [ - "###TODO###\n", - "\n", - "# Put the pairs of input and target images into the `source_dir` and `target_dir`, respectively.\n", - "# The goal here is to the train a network for 1 us scan time." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "77d66a60", - "metadata": { - "lines_to_next_cell": 0, - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "# Since we wanna train a network for images of 1us scan time, we will use the two images as our input images.\n", - "# For both of these images we can use every other image as our target - as long as the noise is different the\n", - "# only remaining structure is the signal, so mixing different scan times is totally fine.\n", - "# Images are paired by having the same name in `source_dir` and `target_dir`. This means we'll have several\n", - "# copies of the same image with different names. These images aren't very big, so that's fine.\n", - "counter = 0\n", - "for i in range(2, 4):\n", - " for j in range(x_train.shape[0]):\n", - " if i == j:\n", - " continue\n", - " imwrite(os.path.join(source_dir, f\"{counter}.tif\"), x_train[i, ...])\n", - " imwrite(os.path.join(target_dir, f\"{counter}.tif\"), x_train[j, ...])\n", - " counter += 1" - ] - }, - { - "cell_type": "markdown", - "id": "dfc0f4ae", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "---\n", - "---\n", - "

\n", - " TASK 2.3:

\n", - "

\n", - " Now that you arranged the training data we can now create the raw data object.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f048fbce", - "metadata": {}, - "outputs": [], - "source": [ - "###TODO###\n", - "raw_data = RawData.from_folder(\n", - " basepath=\"data/SEM/train\",\n", - " source_dirs=[\"\"], # fill in your directory for source images\n", - " target_dir=\"\", # fill in your directory of target images\n", - " axes=\"\", # what should the axes tag be?\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0a906e27", - "metadata": { - "lines_to_next_cell": 0, - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "raw_data = RawData.from_folder(\n", - " basepath=\"data/SEM/train\",\n", - " source_dirs=[\"source_1us\"], # fill in your directory for source images\n", - " target_dir=\"target_1us\", # fill in your directory of target images\n", - " axes=\"YX\", # what should the axes tag be?\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "86a23463", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---\n", - "We generate 2D patches. If you'd like, you can play around with the parameters here." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ef0ee336", - "metadata": {}, - "outputs": [], - "source": [ - "X, Y, XY_axes = create_patches(\n", - " raw_data=raw_data,\n", - " patch_size=(256, 256),\n", - " n_patches_per_image=512,\n", - " save_file=\"data/SEM/my_1us_training_data.npz\",\n", - ")\n", - "\n", - "assert X.shape == Y.shape\n", - "print(\"shape of X,Y =\", X.shape)\n", - "print(\"axes of X,Y =\", XY_axes)" - ] - }, - { - "cell_type": "markdown", - "id": "daf15a26", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Show\n", - "\n", - "Let's look at some of the generated patch pairs. (odd rows: _source_, even rows: _target_)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6227c8fe", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "for i in range(2):\n", - " plt.figure(figsize=(16, 4))\n", - " sl = slice(8 * i, 8 * (i + 1)), 0\n", - " plot_some(\n", - " X[sl], Y[sl], title_list=[np.arange(sl[0].start, sl[0].stop)], cmap=\"gray_r\"\n", - " )\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "fbaf33e4", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Part 2: Training the network\n", - "\n", - "\n", - "### Load Training data\n", - "\n", - "Load the patches generated in part 1, use 10% as validation data." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ef2231ad", - "metadata": {}, - "outputs": [], - "source": [ - "(X, Y), (X_val, Y_val), axes = load_training_data(\n", - " \"data/SEM/my_1us_training_data.npz\", validation_split=0.1, verbose=True\n", - ")\n", - "\n", - "c = axes_dict(axes)[\"C\"]\n", - "n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n", - "\n", - "\n", - "plt.figure(figsize=(12, 5))\n", - "plot_some(X_val[:5], Y_val[:5], cmap=\"gray_r\", pmin=0.2, pmax=99.8)\n", - "plt.suptitle(\"5 example validation patches (top row: source, bottom row: target)\")\n", - "\n", - "config = Config(\n", - " axes, n_channel_in, n_channel_out, train_steps_per_epoch=10, train_epochs=100\n", - ")\n", - "vars(config)" - ] - }, - { - "cell_type": "markdown", - "id": "c53ca47d", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "We now create a CARE model with the chosen configuration:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "386877f3", - "metadata": {}, - "outputs": [], - "source": [ - "model = CARE(config, \"my_N2N_model\", basedir=\"models\")" - ] - }, - { - "cell_type": "markdown", - "id": "4a170adb", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Training\n", - "\n", - "Training the model will likely take some time. We recommend to monitor the progress with [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard), which allows you to inspect the losses during training.\n", - "Furthermore, you can look at the predictions for some of the validation images, which can be helpful to recognize problems early on.\n", - "\n", - "Start tensorboard as you did in the previous notebook." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "383cc0fb", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "%tensorboard --logdir models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "afd5ce7b", - "metadata": {}, - "outputs": [], - "source": [ - "history = model.train(X, Y, validation_data=(X_val, Y_val))" - ] - }, - { - "cell_type": "markdown", - "id": "242c2a9c", - "metadata": {}, - "source": [ - "Plot final training history (available in TensorBoard during training):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7607957c", - "metadata": {}, - "outputs": [], - "source": [ - "print(sorted(list(history.history.keys())))\n", - "plt.figure(figsize=(16, 5))\n", - "plot_history(history, [\"loss\", \"val_loss\"], [\"mse\", \"val_mse\", \"mae\", \"val_mae\"])" - ] - }, - { - "cell_type": "markdown", - "id": "a8b12c16", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Evaluation\n", - "Example results for validation images." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7d920b92", - "metadata": {}, - "outputs": [], - "source": [ - "plt.figure(figsize=(12, 7))\n", - "_P = model.keras_model.predict(X_val[:5])\n", - "if config.probabilistic:\n", - " _P = _P[..., : (_P.shape[-1] // 2)]\n", - "plot_some(X_val[:5], Y_val[:5], _P, pmin=0.2, pmax=99.8, cmap=\"gray_r\")\n", - "plt.suptitle(\n", - " \"5 example validation patches\\n\"\n", - " \"top row: input (noisy source), \"\n", - " \"mid row: target (independently noisy), \"\n", - " \"bottom row: predicted from source, \"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "72321ef2", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Part 3: Prediction\n", - "\n", - "\n", - "### Load CARE model\n", - "\n", - "Load trained model (located in base directory `models` with name `my_model`) from disk.\n", - "The configuration was saved during training and is automatically loaded when `CARE` is initialized with `config=None`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dbdb29ac", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "model = CARE(config=None, name=\"my_N2N_model\", basedir=\"models\")" - ] - }, - { - "cell_type": "markdown", - "id": "ee7ffaf8", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Apply CARE network to raw image\n", - "Now use the trained model to denoise some test images. Let's load the whole tiff stack first" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c6c2f73d", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "path_test_data = \"data/SEM/test/test.tif\"\n", - "test_imgs = imread(path_test_data)\n", - "axes = \"YX\"\n", - "\n", - "# separate out the high SNR image as before\n", - "x_test, x_test_highSNR = test_imgs[:-1], test_imgs[-1]" - ] - }, - { - "cell_type": "markdown", - "id": "0112bf1b", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "---\n", - "

\n", - " TASK 2.4:

\n", - "

\n", - " Write a function that applies the model to one of the images in the tiff stack. Code to visualize the result by plotting the noisy image alongside the restored image as well as smaller crops of each is provided.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fbc45be7", - "metadata": {}, - "outputs": [], - "source": [ - "###TODO###\n", - "def apply_on_test(predict_model, img_idx, plot=True):\n", - " \"\"\"\n", - " Apply the given model on the test image at the given index of the tiff stack.\n", - " Returns the noisy image, restored image and the scantime.\n", - " \"\"\"\n", - " # TODO: insert your code for prediction here\n", - " scantime = ... # get scantime for `img_idx`th image\n", - " img = ... # get `img_idx`th image\n", - " restored = ... # apply model to `img`\n", - " if plot:\n", - " img_crop = img[500:756, 200:456]\n", - " restored_crop = restored[500:756, 200:456]\n", - " x_test_highSNR_crop = x_test_highSNR[500:756, 200:456]\n", - " plt.figure(figsize=(20, 30))\n", - " plot_some(\n", - " np.stack([img, restored, x_test_highSNR]),\n", - " np.stack([img_crop, restored_crop, x_test_highSNR_crop]),\n", - " cmap=\"gray_r\",\n", - " title_list=[[scantime, \"restored\", scantime_highSNR]],\n", - " )\n", - " return img, restored, scantime" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d447ae9e", - "metadata": { - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "def apply_on_test(predict_model, img_idx, plot=True):\n", - " \"\"\"\n", - " Apply the given model on the test image at the given index of the tiff stack.\n", - " Returns the noisy image, restored image and the scantime.\n", - " \"\"\"\n", - " scantime = scantimes[img_idx]\n", - " img = x_test[img_idx, ...]\n", - " axes = \"YX\"\n", - " restored = predict_model.predict(img, axes)\n", - " if plot:\n", - " img_crop = img[500:756, 200:456]\n", - " restored_crop = restored[500:756, 200:456]\n", - " x_test_highSNR_crop = x_test_highSNR[500:756, 200:456]\n", - " plt.figure(figsize=(20, 30))\n", - " plot_some(\n", - " np.stack([img, restored, x_test_highSNR]),\n", - " np.stack([img_crop, restored_crop, x_test_highSNR_crop]),\n", - " cmap=\"gray_r\",\n", - " title_list=[[scantime, \"restored\", scantime_highSNR]],\n", - " )\n", - " return img, restored, scantime" - ] - }, - { - "cell_type": "markdown", - "id": "770d410b", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---\n", - "\n", - "Using the function you just wrote to restore one of the images with 1us scan time." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2e780e06", - "metadata": {}, - "outputs": [], - "source": [ - "noisy_img, restored_img, scantime = apply_on_test(model, 2)\n", - "\n", - "ssi_input = structural_similarity(noisy_img, x_test_highSNR, data_range=65535)\n", - "ssi_restored = structural_similarity(restored_img, x_test_highSNR, data_range=65535)\n", - "print(\n", - " f\"Structural similarity index (higher is better) wrt average of 4x5us images: \\n\"\n", - " f\"Input: {ssi_input} \\n\"\n", - " f\"Prediction: {ssi_restored}\"\n", - ")\n", - "\n", - "psnr_input = peak_signal_noise_ratio(noisy_img, x_test_highSNR, data_range=65535)\n", - "psnr_restored = peak_signal_noise_ratio(restored_img, x_test_highSNR, data_range=65535)\n", - "print(\n", - " f\"Peak signal-to-noise ratio wrt average of 4x5us images:\\n\"\n", - " f\"Input: {psnr_input} \\n\"\n", - " f\"Prediction: {psnr_restored}\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "b268fafe", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "---\n", - "

\n", - " TASK 2.5:

\n", - "

\n", - " Be creative!\n", - "\n", - "Can you improve the results by using the data differently or by tweaking the settings?\n", - "\n", - "How could you train a single network to process all scan times?\n", - "

\n", - "
" - ] - }, - { - "cell_type": "markdown", - "id": "12de7fb3", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "To train a network to process all scan times use this instead as the solution to Task 2.3:\n", - "The names \"low\" and \"GT\" don't really fit here anymore, so use names \"source_all\" and \"target_all\" instead" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "87183177", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "source_dir = \"data/SEM/train/source_all\"\n", - "target_dir = \"data/SEM/train/target_all\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "675708bd", - "metadata": {}, - "outputs": [], - "source": [ - "os.makedirs(source_dir, exist_ok=True)\n", - "os.makedirs(target_dir, exist_ok=True)" - ] - }, - { - "cell_type": "markdown", - "id": "87127c2e", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Since we wanna train a network for all scan times, we will use all images as our input images.\n", - "To train Noise2Noise we can use every other image as our target - as long as the noise is different the only remianing structure is the signal, so mixing different scan times is totally fine.\n", - "Images are paired by having the same name in `source_dir` and `target_dir`. This means we'll have several copies of the same image with different names. These images aren't very big, so that's fine." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a9ff3e56", - "metadata": {}, - "outputs": [], - "source": [ - "counter = 0\n", - "for i in range(x_train.shape[0]):\n", - " for j in range(x_train.shape[0]):\n", - " if i == j:\n", - " continue\n", - " imwrite(os.path.join(source_dir, f\"{counter}.tif\"), x_train[i, ...])\n", - " imwrite(os.path.join(target_dir, f\"{counter}.tif\"), x_train[j, ...])\n", - " counter += 1" - ] - }, - { - "cell_type": "markdown", - "id": "fbf87638", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "---\n", - "
\n", - "

\n", - " Congratulations!

\n", - "

\n", - " You have reached the second checkpoint of this exercise! Please mark your progress in the course chat!\n", - "

\n", - "
" - ] - } - ], - "metadata": { - "jupytext": { - "cell_metadata_filter": "all", - "main_language": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/solution3.ipynb b/solution3.ipynb deleted file mode 100644 index fa99ae9..0000000 --- a/solution3.ipynb +++ /dev/null @@ -1,704 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "7d405b3a", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "# Train a Noise2Void network\n", - "\n", - "Both the CARE network and Noise2Noise network you trained in part 1 and 2 require that you acquire additional data for the purpose of denoising. For CARE we used a paired acquisition with high SNR, for Noise2Noise we had paired noisy acquisitions.\n", - "We will now train a Noise2Void network from single noisy images.\n", - "\n", - "This notebook uses a single image from the SEM data from the Noise2Noise notebook, but as you'll see in Task 3.1 if you brought your own raw data you should adapt the notebook to use that instead.\n", - "\n", - "We now use the [Noise2Void library](https://github.com/juglab/n2v) instead of csbdeep/care, but don't worry - they're pretty similar.\n", - "\n", - "
\n", - "Set your python kernel to 03_image_restoration_part2\n", - "
\n", - "
\n", - "Make sure your previous notebook is shutdown to avoid running into GPU out-of-memory problems.\n", - "
\n", - "\n", - "---\n", - "\n", - "

\n", - " TASK 3.1

\n", - "

\n", - "This notebook uses a single image from the SEM data from the Noise2Noise notebook.\n", - "\n", - "If you brought your own raw data, use that instead!\n", - "The only requirement is that the noise in your data is pixel-independent and zero-mean. If you're unsure whether your data fulfills that requirement or you don't yet understand why it is necessary ask one of us to discuss!\n", - "\n", - "If you don't have suitable data of your own, feel free to find some online or ask your fellow course participants. You can however also stick with the SEM data provided here and compare the results to what you achieved with Noise2Noise in the previous part.\n", - "

\n", - "
\n", - "\n", - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f253352a", - "metadata": {}, - "outputs": [], - "source": [ - "# We import all our dependencies.\n", - "from n2v.models import N2VConfig, N2V\n", - "import numpy as np\n", - "from csbdeep.utils import plot_history\n", - "from n2v.utils.n2v_utils import manipulate_val_data\n", - "from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n", - "from matplotlib import pyplot as plt\n", - "import urllib\n", - "import os\n", - "from skimage.metrics import structural_similarity, peak_signal_noise_ratio\n", - "from tifffile import imread\n", - "import zipfile\n", - "\n", - "%load_ext tensorboard\n", - "\n", - "import ssl\n", - "\n", - "ssl._create_default_https_context = ssl._create_unverified_context" - ] - }, - { - "cell_type": "markdown", - "id": "557ec582", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Part 1: Prepare data\n", - "Let's make sure the data is there!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ade8c11d", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "assert os.path.exists(\"data/SEM/train/train.tif\")\n", - "assert os.path.exists(\"data/SEM/test/test.tif\")" - ] - }, - { - "cell_type": "markdown", - "id": "0f458875", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "We create a N2V_DataGenerator object to help load data and extract patches for training and validation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "569e0c45", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "datagen = N2V_DataGenerator()" - ] - }, - { - "cell_type": "markdown", - "id": "90ef2146", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "The data generator provides two methods for loading data: `load_imgs_from_directory` and `load_imgs`. Let's look at their docstring to figure out how to use it." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2e752db3", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "?N2V_DataGenerator.load_imgs_from_directory" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3b68d54f", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "?N2V_DataGenerator.load_imgs" - ] - }, - { - "cell_type": "markdown", - "id": "7cd57bd4", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "The SEM images are all in one directory, so we'll use `load_imgs_from_directory`. We'll pass in that directory (`\"data/SEM/train\"`), our image matches the default filter (`\"*.tif\"`) so we do not need to specify that. But our tif image is a stack of several images, so as dims we need to specify `\"TYX\"`.\n", - "If you're using your own data adapt this part to match your use case. If these functions aren't suitable for your use case load your images manually.\n", - "Feel free to ask a TA for help if you're unsure how to get your data loaded!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "289e03ce", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "imgs = datagen.load_imgs_from_directory(\"data/SEM/train\", dims=\"TYX\")\n", - "print(f\"Loaded {len(imgs)} images.\")\n", - "print(f\"First image has shape {imgs[0].shape}\")" - ] - }, - { - "cell_type": "markdown", - "id": "3bb5b8f5", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "The method returned a list of images, as per the doc string the dimensions of each are \"SYXC\". However, we only want to use one of the images here since Noise2Void is designed to work with just one acquisition of the sample. Let's use the first image at $1\\mu s$ scantime." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "02c80399", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "imgs = [img[2:3, :, :, :] for img in imgs]\n", - "print(f\"First image has shape {imgs[0].shape}\")" - ] - }, - { - "cell_type": "markdown", - "id": "18841f39", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "For generating patches the datagenerator provides the methods `generate_patches` and `generate_patches_from_list`. 
As before, let's have a quick look at the docstring" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8acfd6f1", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "?N2V_DataGenerator.generate_patches" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ad205d8f", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "?N2V_DataGenerator.generate_patches_from_list" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d1ffa91f", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "type(imgs)" - ] - }, - { - "cell_type": "markdown", - "id": "4073063c", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Our `imgs` object is a list, so `generate_patches_from_list` is the suitable function." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fd8ad59a", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "patches = datagen.generate_patches_from_list(imgs, shape=(96, 96))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bf2fcfce", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "# split into training and validation\n", - "n_train = int(round(0.9 * patches.shape[0]))\n", - "X, X_val = patches[:n_train, ...], patches[n_train:, ...]" - ] - }, - { - "cell_type": "markdown", - "id": "09ded741", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "As per usual, let's look at a training and validation patch to make sure everything looks okay." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b5e2aa5a", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "plt.figure(figsize=(14, 7))\n", - "plt.subplot(1, 2, 1)\n", - "plt.imshow(X[np.random.randint(X.shape[0]), ..., 0], cmap=\"gray_r\")\n", - "plt.title(\"Training patch\")\n", - "plt.subplot(1, 2, 2)\n", - "plt.imshow(X_val[np.random.randint(X_val.shape[0]), ..., 0], cmap=\"gray_r\")\n", - "plt.title(\"Validation patch\")" - ] - }, - { - "cell_type": "markdown", - "id": "9adf5aae", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Part 2: Configure and train the Noise2Void Network\n", - "\n", - "Noise2Void comes with a special config-object, where we store network-architecture and training specific parameters. See the docstring of the N2VConfig constructor for a description of all parameters.\n", - "\n", - "When creating the config-object, we provide the training data X. From X the library will extract mean and std that will be used to normalize all data before it is processed by the network.\n", - "\n", - "\n", - "Compared to supervised training (i.e. traditional CARE), we recommend to use N2V with an increased train_batch_size (e.g. 128) and batch_norm.\n", - "\n", - "To keep the network from learning the identity we have to manipulate the input pixels for the blindspot during training. How to exactly manipulate those values is controlled via the n2v_manipulator parameter with default value 'uniform_withCP' which samples a random value from the surrounding pixels, including the value at the control point. The size of the surrounding area can be configured via n2v_neighborhood_radius.\n", - "\n", - "The [paper supplement](https://arxiv.org/src/1811.10980v2/anc/supp_small.pdf) describes other pixel manipulators as well (section 3.1). If you want to configure one of those use the following values for n2v_manipulator:\n", - "* \"normal_additive\" for Gaussian (n2v_neighborhood_radius will set sigma)\n", - "* \"normal_fitted\" for Gaussian Fitting\n", - "* \"normal_withoutCP\" for Gaussian Pixel Selection\n", - "\n", - "For faster training multiple pixels per input patch can be manipulated. In our experiments we manipulated about 0.198% of the input pixels per patch. For a patch size of 64 by 64 pixels this corresponds to about 8 pixels. This fraction can be tuned via n2v_perc_pix.\n", - "\n", - "For Noise2Void training it is possible to pass arbitrarily large patches to the training method. From these patches random subpatches of size n2v_patch_shape are extracted during training. Default patch shape is set to (64, 64).\n", - "\n", - "In the past we experienced bleedthrough artifacts between channels if training was terminated to early. To counter bleedthrough we added the `single_net_per_channel` option, which is turned on by default. In the back a single U-Net for each channel is created and trained independently, thereby removing the possiblity of bleedthrough.
\n", - "Essentially the network gets multiplied by the number of channels, which increases the memory requirements. If your GPU gets too small, you can always split the channels manually and train a network for each channel one after another.\n", - "\n", - "---\n", - "

\n", - " TASK 3.2

\n", - "

\n", - "As suggested look at the docstring of the N2VConfig and then generate a configuration for your Noise2Void network, and choose a name to identify your model by.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9ec86b39", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "?N2VConfig" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "993e8ac2", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "###TODO###\n", - "config = N2VConfig()\n", - "vars(config)\n", - "model_name = \"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "27000819", - "metadata": { - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "# train_steps_per_epoch is set to (number of training patches)/(batch size), like this each training patch\n", - "# is shown once per epoch.\n", - "config = N2VConfig(\n", - " X,\n", - " unet_kern_size=3,\n", - " train_steps_per_epoch=int(X.shape[0] / 128),\n", - " train_epochs=200,\n", - " train_loss=\"mse\",\n", - " batch_norm=True,\n", - " train_batch_size=128,\n", - " n2v_perc_pix=0.198,\n", - " n2v_patch_shape=(64, 64),\n", - " n2v_manipulator=\"uniform_withCP\",\n", - " n2v_neighborhood_radius=5,\n", - ")\n", - "\n", - "# Let's look at the parameters stored in the config-object.\n", - "vars(config)\n", - "model_name = \"n2v_2D\"" - ] - }, - { - "cell_type": "markdown", - "id": "61203b93", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8f41bab6", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "# initialize the model\n", - "model = N2V(config, model_name, basedir=\"models\")" - ] - }, - { - "cell_type": "markdown", - "id": "35bf20d3", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Now let's train the model and monitor the progress in tensorboard.\n", - "Adapt the command below as you did before." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "26940a2f", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "%tensorboard --logdir=models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aaeb5c02", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "history = model.train(X, X_val)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2a653ca4", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "print(sorted(list(history.history.keys())))\n", - "plt.figure(figsize=(16, 5))\n", - "plot_history(history, [\"loss\", \"val_loss\"])" - ] - }, - { - "cell_type": "markdown", - "id": "b520725d", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Part 3: Prediction\n", - "\n", - "Similar to CARE a previously trained model is loaded by creating a new N2V-object without providing a `config`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a00aee95", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "model = N2V(config=None, name=model_name, basedir=\"models\")" - ] - }, - { - "cell_type": "markdown", - "id": "e1b5d86e", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Let's load a $1\\mu s$ scantime test images and denoise them using our network and like before we'll use the high SNR image to make a quantitative comparison. If you're using your own data and don't have an equivalent you can ignore that part." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c429e183", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "test_img = imread(\"data/SEM/test/test.tif\")[2, ...]\n", - "test_img_highSNR = imread(\"data/SEM/test/test.tif\")[-1, ...]\n", - "print(f\"Loaded test image with shape {test_img.shape}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "28325ad3", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "test_denoised = model.predict(test_img, axes=\"YX\", n_tiles=(2, 1))" - ] - }, - { - "cell_type": "markdown", - "id": "84a3b87f", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Let's look at the results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9a1ee796", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "plt.figure(figsize=(30, 30))\n", - "plt.subplot(2, 3, 1)\n", - "plt.imshow(test_img, cmap=\"gray_r\")\n", - "plt.title(\"Noisy test image\")\n", - "plt.subplot(2, 3, 4)\n", - "plt.imshow(test_img[2000:2200, 500:700], cmap=\"gray_r\")\n", - "plt.subplot(2, 3, 2)\n", - "plt.imshow(test_denoised, cmap=\"gray_r\")\n", - "plt.title(\"Denoised test image\")\n", - "plt.subplot(2, 3, 5)\n", - "plt.imshow(test_denoised[2000:2200, 500:700], cmap=\"gray_r\")\n", - "plt.subplot(2, 3, 3)\n", - "plt.imshow(test_img_highSNR, cmap=\"gray_r\")\n", - "plt.title(\"High SNR image (4x5us)\")\n", - "plt.subplot(2, 3, 6)\n", - "plt.imshow(test_img_highSNR[2000:2200, 500:700], cmap=\"gray_r\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "561e5559", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---\n", - "

\n", - " TASK 3.3

\n", - "

\n", - "\n", - "If you're using the SEM data (or happen to have a high SNR version of the image you predicted from) compare the structural similarity index and peak signal to noise ratio (wrt the high SNR image) of the noisy input image and the predicted image. If not, just skip this task.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a61bcb1f", - "metadata": {}, - "outputs": [], - "source": [ - "###TODO###\n", - "ssi_input = ... # TODO\n", - "ssi_restored = ... # TODO\n", - "print(\n", - " f\"Structural similarity index (higher is better) wrt average of 4x5us images: \\n\"\n", - " f\"Input: {ssi_input} \\n\"\n", - " f\"Prediction: {ssi_restored}\"\n", - ")\n", - "psnr_input = ... # TODO\n", - "psnr_restored = ... # TODO\n", - "print(\n", - " f\"Peak signal-to-noise ratio (higher is better) wrt average of 4x5us images:\\n\"\n", - " f\"Input: {psnr_input} \\n\"\n", - " f\"Prediction: {psnr_restored}\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "54dc0b47", - "metadata": { - "lines_to_next_cell": 0, - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "ssi_input = structural_similarity(test_img, test_img_highSNR, data_range=65535)\n", - "ssi_restored = structural_similarity(test_denoised, test_img_highSNR, data_range=65535)\n", - "print(\n", - " f\"Structural similarity index (higher is better) wrt average of 4x5us images: \\n\"\n", - " f\"Input: {ssi_input} \\n\"\n", - " f\"Prediction: {ssi_restored}\"\n", - ")\n", - "psnr_input = peak_signal_noise_ratio(test_img, test_img_highSNR, data_range=65535)\n", - "psnr_restored = peak_signal_noise_ratio(\n", - " test_denoised, test_img_highSNR, data_range=65535\n", - ")\n", - "print(\n", - " f\"Peak signal-to-noise ratio (higher is better) wrt average of 4x5us images:\\n\"\n", - " f\"Input: {psnr_input} \\n\"\n", - " f\"Prediction: {psnr_restored}\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "8e5e97cc", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "---\n", - "
\n", - "

\n", - " Congratulations!

\n", - "

\n", - " You have reached the third checkpoint of this exercise! Please mark your progress in the course chat!\n", - "

\n", - "

\n", - " Consider sharing some pictures of your results on element, especially if you used your own data.\n", - "

\n", - "

\n", - " If there's still time, check out the bonus exercise.\n", - "

\n", - "
" - ] - } - ], - "metadata": { - "jupytext": { - "cell_metadata_filter": "all", - "main_language": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/solution_bonus1.ipynb b/solution_bonus1.ipynb deleted file mode 100644 index 3c75a5f..0000000 --- a/solution_bonus1.ipynb +++ /dev/null @@ -1,1188 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "b1b7576d", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "# Train Probabilistic Noise2Void\n", - "\n", - "Probabilistic Noise2Void, just as N2V, allows training from single noisy images.\n", - "\n", - "In order to get some additional quality squeezed out of your noisy input data, PN2V employs an additional noise model which can either be measured directly at your microscope or approximated by a process called ‘bootstrapping’.\n", - "Below we will give you a noise model for the first network to train and then bootstrap one, so you can apply PN2V to your own data if you'd like.\n", - "\n", - "Note: The PN2V implementation is written in pytorch, not Keras/TF.\n", - "\n", - "Note: PN2V experienced multiple updates regarding noise model representations. Hence, the [original PN2V repository](https://github.com/juglab/pn2v) is not any more the one we suggest to use (despite it of course working just as described in the original publication). So here we use the [PPN2V repo](https://github.com/juglab/PPN2V) which you installed during setup.\n", - "\n", - "
\n", - "Set your python kernel to 03_image_restoration_bonus\n", - "
\n", - "
\n", - "Make sure your previous notebook is shutdown to avoid running into GPU out-of-memory problems.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a56c4a75", - "metadata": {}, - "outputs": [], - "source": [ - "import warnings\n", - "\n", - "warnings.filterwarnings(\"ignore\")\n", - "import torch\n", - "\n", - "dtype = torch.float\n", - "device = torch.device(\"cuda:0\")\n", - "from torch.distributions import normal\n", - "import matplotlib.pyplot as plt, numpy as np, pickle\n", - "from scipy.stats import norm\n", - "from tifffile import imread\n", - "import sys\n", - "import os\n", - "import urllib\n", - "import zipfile" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9ce2cb17", - "metadata": {}, - "outputs": [], - "source": [ - "from ppn2v.pn2v import histNoiseModel, gaussianMixtureNoiseModel\n", - "from ppn2v.pn2v.utils import plotProbabilityDistribution, PSNR\n", - "from ppn2v.unet.model import UNet\n", - "from ppn2v.pn2v import training, prediction" - ] - }, - { - "cell_type": "markdown", - "id": "e8f8283c", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "## Data Preperation\n", - "\n", - "Here we use a dataset of 2D images of fluorescently labeled membranes of Convallaria (lilly of the valley) acquired with a spinning disk microscope.\n", - "All 100 recorded images (1024×1024 pixels) show the same region of interest and only differ in their noise." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f62d2875", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "# Check that data download was successful\n", - "assert os.path.exists(\"data/Convallaria_diaphragm\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7c73978a", - "metadata": {}, - "outputs": [], - "source": [ - "path = \"data/Convallaria_diaphragm/\"\n", - "data_name = \"convallaria\" # Name of the noise model\n", - "calibration_fn = \"20190726_tl_50um_500msec_wf_130EM_FD.tif\"\n", - "noisy_fn = \"20190520_tl_25um_50msec_05pc_488_130EM_Conv.tif\"\n", - "noisy_imgs = imread(path + noisy_fn)\n", - "calibration_imgs = imread(path + calibration_fn)" - ] - }, - { - "cell_type": "markdown", - "id": "773f73ca", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "This notebook has a total of four options to generate a noise model for PN2V. You can pick which one you would like to use (and ignore the tasks in the options you don't wanna use)!\n", - "\n", - "There are two types of noise models for PN2V: creating a histogram of the noisy pixels based on the averaged GT or using a gaussian mixture model (GMM).\n", - "For both we need to provide a clean signal as groundtruth. For the dataset we have here we have calibration data available so you can choose between using the calibration data or bootstrapping the model by training a N2V network." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "78c9cfb5", - "metadata": {}, - "outputs": [], - "source": [ - "n_gaussian = 3 # Number of gaussians to use for Gaussian Mixture Model\n", - "n_coeff = 2 # No. of polynomial coefficients for parameterizing the mean, standard deviation and weight of Gaussian components." - ] - }, - { - "cell_type": "markdown", - "id": "dbfe7373", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Choice 1: Generate a Noise Model using Calibration Data\n", - "The noise model is a characteristic of your camera. The downloaded data folder contains a set of calibration images (For the Convallaria dataset, it is ```20190726_tl_50um_500msec_wf_130EM_FD.tif``` and the data to be denoised is named ```20190520_tl_25um_50msec_05pc_488_130EM_Conv.tif```). We can either bin the noisy - GT pairs (obtained from noisy calibration images) as a 2-D histogram or fit a GMM distribution to obtain a smooth, parametric description of the noise model.\n", - "\n", - "We will use pairs of noisy calibration observations $x_i$ and clean signal $s_i$ (created by averaging these noisy, calibration images) to estimate the conditional distribution $p(x_i|s_i)$. Histogram-based and Gaussian Mixture Model-based noise models are generated and saved." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4f08cf73", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "name_hist_noise_model_cal = \"_\".join([\"HistNoiseModel\", data_name, \"calibration\"])\n", - "name_gmm_noise_model_cal = \"_\".join(\n", - " [\"GMMNoiseModel\", data_name, str(n_gaussian), str(n_coeff), \"calibration\"]\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "b1b1ae65", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---\n", - "

\n", - " TASK 4.1

\n", - "

\n", - "\n", - "The calibration data contains 100 images of a static sample. Estimate the clean signal by averaging all the images.\n", - "

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d828180c", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "###TODO###\n", - "# Average the images in `calibration_imgs`\n", - "signal_cal = ... # TODO" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ec293abc", - "metadata": { - "lines_to_next_cell": 0, - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "# Average the images in `calibration_imgs`\n", - "signal_cal = np.mean(calibration_imgs[:, ...], axis=0)[np.newaxis, ...]" - ] - }, - { - "cell_type": "markdown", - "id": "96746d74", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Let's visualize a single image from the observation array alongside the average to see how the raw data compares to the pseudo ground truth signal." - ] - }, - { - "cell_type": "markdown", - "id": "2b50122c", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d71a7778", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "plt.figure(figsize=(12, 12))\n", - "plt.subplot(1, 2, 2)\n", - "plt.title(label=\"average (ground truth)\")\n", - "plt.imshow(signal_cal[0], cmap=\"gray\")\n", - "plt.subplot(1, 2, 1)\n", - "plt.title(label=\"single raw image\")\n", - "plt.imshow(calibration_imgs[0], cmap=\"gray\")\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1456576b", - "metadata": {}, - "outputs": [], - "source": [ - "# The subsequent code expects the signal array to have a dimension for the samples\n", - "if signal_cal.shape == calibration_imgs.shape[1:]:\n", - " signal_cal = signal_cal[np.newaxis, ...]" - ] - }, - { - "cell_type": "markdown", - "id": "c690fc8f", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "There are two ways of generating a noise model for PN2V: creating a histogram of the noisy pixels based on the averaged GT or using a gaussian mixture model (GMM). You can pick which one you wanna use!\n", - "\n", - "
\n", - "\n", - "### Choice 1A: Creating the Histogram Noise Model\n", - "Using the raw pixels $x_i$, and our averaged GT $s_i$, we are now learning a histogram based noise model. It describes the distribution $p(x_i|s_i)$ for each $s_i$.\n", - "\n", - "---\n", - "

\n", - " TASK 4.2

\n", - "

\n", - " Look at the docstring for createHistogram and use it to create a histogram based on the calibration data using the clean signal you created by averaging as groundtruth.

\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fbd00eb2", - "metadata": {}, - "outputs": [], - "source": [ - "?histNoiseModel.createHistogram" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb78b79e", - "metadata": {}, - "outputs": [], - "source": [ - "###TODO###\n", - "# Define the parameters for the histogram creation\n", - "bins = 256\n", - "# Values falling outside the range [min_val, max_val] are not included in the histogram, so the values in the images you want to denoise should fall within that range\n", - "min_val = ... # TODO\n", - "max_val = ... # TODO\n", - "# Create the histogram\n", - "histogram_cal = histNoiseModel.createHistogram(bins, ...) # TODO" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d0ad22ca", - "metadata": { - "lines_to_next_cell": 0, - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "# Define the parameters for the histogram creation\n", - "bins = 256\n", - "# Values falling outside the range [min_val, max_val] are not included in the histogram, so the values in the images you want to denoise should fall within that range\n", - "min_val = 234 # np.min(noisy_imgs)\n", - "max_val = 7402 # np.max(noisy_imgs)\n", - "print(\"min:\", min_val, \", max:\", max_val)\n", - "# Create the histogram\n", - "histogram_cal = histNoiseModel.createHistogram(\n", - " bins, min_val, max_val, calibration_imgs, signal_cal\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "5ea0dffb", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fc393b96", - "metadata": {}, - "outputs": [], - "source": [ - "# Saving histogram to disk.\n", - "np.save(path + name_hist_noise_model_cal + \".npy\", histogram_cal)\n", - "histogramFD_cal = histogram_cal[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4f920fcf", - "metadata": {}, - "outputs": [], - "source": [ - "# Let's look at the histogram-based noise model.\n", - "plt.xlabel(\"Observation Bin\")\n", - "plt.ylabel(\"Signal Bin\")\n", - "plt.imshow(histogramFD_cal**0.25, cmap=\"gray\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "5993f09c", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "### Choice 1B: Creating the GMM noise model\n", - "Using the raw pixels $x_i$, and our averaged GT $s_i$, we are now learning a GMM based noise model. It describes the distribution $p(x_i|s_i)$ for each $s_i$." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "655c66f9", - "metadata": {}, - "outputs": [], - "source": [ - "min_signal = np.min(signal_cal)\n", - "max_signal = np.max(signal_cal)\n", - "print(\"Minimum Signal Intensity is\", min_signal)\n", - "print(\"Maximum Signal Intensity is\", max_signal)" - ] - }, - { - "cell_type": "markdown", - "id": "35722d03", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Iterating the noise model training for `n_epoch=2000` and `batchSize=250000` works the best for `Convallaria` dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b056b9e6", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "?gaussianMixtureNoiseModel.GaussianMixtureNoiseModel" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9ffb712e", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "gmm_noise_model_cal = gaussianMixtureNoiseModel.GaussianMixtureNoiseModel(\n", - " min_signal=min_signal,\n", - " max_signal=max_signal,\n", - " path=path,\n", - " weight=None,\n", - " n_gaussian=n_gaussian,\n", - " n_coeff=n_coeff,\n", - " min_sigma=50,\n", - " device=device,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aa8892fd", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "gmm_noise_model_cal.train(\n", - " signal_cal,\n", - " calibration_imgs,\n", - " batchSize=250000,\n", - " n_epochs=2000,\n", - " learning_rate=0.1,\n", - " name=name_gmm_noise_model_cal,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "7305eeb0", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "### Visualizing the Histogram-based and GMM-based noise models\n", - "\n", - "This only works if you generated both a histogram (Choice 1A) and GMM-based (Choice 1B) noise model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d060c437", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "plotProbabilityDistribution(\n", - " signalBinIndex=170,\n", - " histogram=histogramFD_cal,\n", - " gaussianMixtureNoiseModel=gmm_noise_model_cal,\n", - " min_signal=min_val,\n", - " max_signal=max_val,\n", - " n_bin=bins,\n", - " device=device,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "e63e2061", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## Choice 2: Generate a Noise Model by Bootstrapping\n", - "\n", - "Here we bootstrap a suitable histogram noise model and a GMM noise model after denoising the noisy images with Noise2Void and then using these denoised images as pseudo GT.\n", - "So first, we need to train a N2V model (now with pytorch) to estimate the conditional distribution $p(x_i|s_i)$. No additional calibration data is used for bootstrapping (so no need to use `calibration_imgs` or `singal_cal` again)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8a4145cb", - "metadata": {}, - "outputs": [], - "source": [ - "model_name = data_name + \"_n2v\"\n", - "name_hist_noise_model_bootstrap = \"_\".join([\"HistNoiseModel\", data_name, \"bootstrap\"])\n", - "name_gmm_noise_model_bootstrap = \"_\".join(\n", - " [\"GMMNoiseModel\", data_name, str(n_gaussian), str(n_coeff), \"bootstrap\"]\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f076055e", - "metadata": {}, - "outputs": [], - "source": [ - "# Configure the Noise2Void network\n", - "n2v_net = UNet(1, depth=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0d02b99c", - "metadata": {}, - "outputs": [], - "source": [ - "# Prepare training+validation data\n", - "train_data = noisy_imgs[:-5].copy()\n", - "val_data = noisy_imgs[-5:].copy()\n", - "np.random.shuffle(train_data)\n", - "np.random.shuffle(val_data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2dfc50a3", - "metadata": {}, - "outputs": [], - "source": [ - "train_history, val_history = training.trainNetwork(\n", - " net=n2v_net,\n", - " trainData=train_data,\n", - " valData=val_data,\n", - " postfix=model_name,\n", - " directory=path,\n", - " noiseModel=None,\n", - " device=device,\n", - " numOfEpochs=200,\n", - " stepsPerEpoch=10,\n", - " virtualBatchSize=20,\n", - " batchSize=1,\n", - " learningRate=1e-3,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3e7261ec", - "metadata": {}, - "outputs": [], - "source": [ - "# Let's look at the training and validation loss\n", - "plt.xlabel(\"epoch\")\n", - "plt.ylabel(\"loss\")\n", - "plt.plot(val_history, label=\"validation loss\")\n", - "plt.plot(train_history, label=\"training loss\")\n", - "plt.legend()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "eb119445", - "metadata": {}, - "outputs": [], - "source": [ - "# We now run the N2V model to create pseudo groundtruth.\n", - "n2v_result_imgs = []\n", - "n2v_input_imgs = []\n", - "\n", - "for index in range(noisy_imgs.shape[0]):\n", - " im = noisy_imgs[index]\n", - " # We are using tiling to fit the image into memory\n", - " # If you get an error try a smaller patch size (ps)\n", - " n2v_pred = prediction.tiledPredict(\n", - " im, n2v_net, ps=256, overlap=48, device=device, noiseModel=None\n", - " )\n", - " n2v_result_imgs.append(n2v_pred)\n", - " n2v_input_imgs.append(im)\n", - " if index % 10 == 0:\n", - " print(\"image:\", index)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fff6264f", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "# In bootstrap mode, we estimate pseudo GT by using N2V denoised images.\n", - "signal_bootstrap = np.array(n2v_result_imgs)\n", - "# Let's look the raw data and our pseudo ground truth signal\n", - "print(signal_bootstrap.shape)\n", - "plt.figure(figsize=(12, 12))\n", - "plt.subplot(2, 2, 2)\n", - "plt.title(label=\"pseudo GT 
(generated by N2V denoising)\")\n", - "plt.imshow(signal_bootstrap[0], cmap=\"gray\")\n", - "plt.subplot(2, 2, 4)\n", - "plt.imshow(signal_bootstrap[0, -128:, -128:], cmap=\"gray\")\n", - "plt.subplot(2, 2, 1)\n", - "plt.title(label=\"single raw image\")\n", - "plt.imshow(noisy_imgs[0], cmap=\"gray\")\n", - "plt.subplot(2, 2, 3)\n", - "plt.imshow(noisy_imgs[0, -128:, -128:], cmap=\"gray\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "fd230f12", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Now that we have pseudoGT, you can pick again between a histogram based noise model and a GMM noise model\n", - "\n", - "
\n", - "\n", - "### Choice 2A: Creating the Histogram Noise Model\n", - "\n", - "---\n", - "**TASK 4.3**\n", - "\n", - "If you've already done Task 4.2, this is very similar!\n", - "Look at the docstring for createHistogram and use it to create a histogram using the bootstrapped signal you created from the N2V predictions.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "88a4cbe7", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "?histNoiseModel.createHistogram" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "09b7ca76", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "###TODO###\n", - "# Define the parameters for the histogram creation\n", - "bins = 256\n", - "# Values falling outside the range [min_val, max_val] are not included in the histogram, so the values in the images you want to denoise should fall within that range\n", - "min_val = ... # TODO\n", - "max_val = ... # TODO\n", - "# Create the histogram\n", - "histogram_bootstrap = histNoiseModel.createHistogram(bins, ...) # TODO" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "497ca2ff", - "metadata": { - "lines_to_next_cell": 0, - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "# Define the parameters for the histogram creation\n", - "bins = 256\n", - "# Values falling outside the range [min_val, max_val] are not included in the histogram, so the values in the images you want to denoise should fall within that range\n", - "min_val = np.min(noisy_imgs)\n", - "max_val = np.max(noisy_imgs)\n", - "# Create the histogram\n", - "histogram_bootstrap = histNoiseModel.createHistogram(\n", - " bins, min_val, max_val, noisy_imgs, signal_bootstrap\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "69aff158", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ad8e6df1", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "# Saving histogram to disk.\n", - "np.save(path + name_hist_noise_model_bootstrap + \".npy\", histogram_bootstrap)\n", - "histogramFD_bootstrap = histogram_bootstrap[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f5ade612", - "metadata": {}, - "outputs": [], - "source": [ - "# Let's look at the histogram-based noise model\n", - "plt.xlabel(\"Observation Bin\")\n", - "plt.ylabel(\"Signal Bin\")\n", - "plt.imshow(histogramFD_bootstrap**0.25, cmap=\"gray\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "f6074610", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "### Choice 2B: Creating the GMM noise model\n", - "Using the raw pixels $x_i$, and our averaged GT $s_i$, we are now learning a GMM based noise model. It describes the distribution $p(x_i|s_i)$ for each $s_i$." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "57f33040", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "min_signal = np.percentile(signal_bootstrap, 0.5)\n", - "max_signal = np.percentile(signal_bootstrap, 99.5)\n", - "print(\"Minimum Signal Intensity is\", min_signal)\n", - "print(\"Maximum Signal Intensity is\", max_signal)" - ] - }, - { - "cell_type": "markdown", - "id": "d775b9a4", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "Iterating the noise model training for `n_epoch=2000` and `batchSize=250000` works the best for `Convallaria` dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "43a50b02", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "gmm_noise_model_bootstrap = gaussianMixtureNoiseModel.GaussianMixtureNoiseModel(\n", - " min_signal=min_signal,\n", - " max_signal=max_signal,\n", - " path=path,\n", - " weight=None,\n", - " n_gaussian=n_gaussian,\n", - " n_coeff=n_coeff,\n", - " device=device,\n", - " min_sigma=50,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4611b54b", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "gmm_noise_model_bootstrap.train(\n", - " signal_bootstrap,\n", - " noisy_imgs,\n", - " batchSize=250000,\n", - " n_epochs=2000,\n", - " learning_rate=0.1,\n", - " name=name_gmm_noise_model_bootstrap,\n", - " lowerClip=0.5,\n", - " upperClip=99.5,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "aaa3f882", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "### Visualizing the Histogram-based and GMM-based noise models\n", - "\n", - "This only works if you generated both a histogram (Choice 2A) and GMM-based (Choice 2B) noise model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "993c6b8e", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "plotProbabilityDistribution(\n", - " signalBinIndex=170,\n", - " histogram=histogramFD_bootstrap,\n", - " gaussianMixtureNoiseModel=gmm_noise_model_bootstrap,\n", - " min_signal=min_val,\n", - " max_signal=max_val,\n", - " n_bin=bins,\n", - " device=device,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "89f86336", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## PN2V Training\n", - "\n", - "---\n", - "**TASK 4.4**\n", - "\n", - "Pick the noise model you want to use (histogram or GMM, from calibration data or from bootstrapping), then train PN2V with it.\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0dffc131", - "metadata": {}, - "outputs": [], - "source": [ - "###TODO###\n", - "noise_model_type = \"gmm\" # pick: \"hist\" or \"gmm\"\n", - "noise_model_data = \"bootstrap\" # pick: \"calibration\" or \"bootstrap\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f67cc3dc", - "metadata": { - "lines_to_next_cell": 0, - "tags": [ - "solution" - ] - }, - "outputs": [], - "source": [ - "if noise_model_type == \"hist\":\n", - " noise_model_name = \"_\".join([\"HistNoiseModel\", data_name, noise_model_data])\n", - " histogram = np.load(path + noise_model_name + \".npy\")\n", - " noise_model = histNoiseModel.NoiseModel(histogram, device=device)\n", - "elif noise_model_type == \"gmm\":\n", - " noise_model_name = \"_\".join(\n", - " [\"GMMNoiseModel\", data_name, str(n_gaussian), str(n_coeff), noise_model_data]\n", - " )\n", - " params = np.load(path + noise_model_name + \".npz\")\n", - " noise_model = gaussianMixtureNoiseModel.GaussianMixtureNoiseModel(\n", - " params=params, device=device\n", - " )" - ] - }, - { - "cell_type": "markdown", - "id": "6bc7c3e9", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4fa867d1", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "# Create a network with 800 output channels that are interpreted as samples from the prior.\n", - "pn2v_net = UNet(800, depth=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "43d6e350", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "# Start training.\n", - "trainHist, valHist = training.trainNetwork(\n", - " net=pn2v_net,\n", - " trainData=train_data,\n", - " valData=val_data,\n", - " postfix=noise_model_name,\n", - " directory=path,\n", - " noiseModel=noise_model,\n", - " device=device,\n", - " numOfEpochs=200,\n", - " stepsPerEpoch=5,\n", - " virtualBatchSize=20,\n", - " batchSize=1,\n", - " learningRate=1e-3,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "57b92b13", - "metadata": { - "cell_marker": "\"\"\"", - "lines_to_next_cell": 0 - }, - "source": [ - "
\n", - "\n", - "## PN2V Evaluation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8ae0bb6d", - "metadata": {}, - "outputs": [], - "source": [ - "test_data = noisy_imgs[\n", - " :, :512, :512\n", - "] # We are loading only a sub image to speed up computation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d074aee5", - "metadata": {}, - "outputs": [], - "source": [ - "# We estimate the ground truth by averaging.\n", - "test_data_gt = np.mean(test_data[:, ...], axis=0)[np.newaxis, ...]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6225e3d3", - "metadata": {}, - "outputs": [], - "source": [ - "pn2v_net = torch.load(path + \"/last_\" + noise_model_name + \".net\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb12628c", - "metadata": {}, - "outputs": [], - "source": [ - "# Now we are processing data and calculating PSNR values.\n", - "mmse_psnrs = []\n", - "prior_psnrs = []\n", - "input_psnrs = []\n", - "result_ims = []\n", - "input_ims = []\n", - "\n", - "# We iterate over all test images.\n", - "for index in range(test_data.shape[0]):\n", - " im = test_data[index]\n", - " gt = test_data_gt[0] # The ground truth is the same for all images\n", - "\n", - " # We are using tiling to fit the image into memory\n", - " # If you get an error try a smaller patch size (ps)\n", - " means, mse_est = prediction.tiledPredict(\n", - " im, pn2v_net, ps=192, overlap=48, device=device, noiseModel=noise_model\n", - " )\n", - "\n", - " result_ims.append(mse_est)\n", - " input_ims.append(im)\n", - "\n", - " range_psnr = np.max(gt) - np.min(gt)\n", - " psnr = PSNR(gt, mse_est, range_psnr)\n", - " psnr_prior = PSNR(gt, means, range_psnr)\n", - " input_psnr = PSNR(gt, im, range_psnr)\n", - " mmse_psnrs.append(psnr)\n", - " prior_psnrs.append(psnr_prior)\n", - " input_psnrs.append(input_psnr)\n", - "\n", - " print(\"image:\", index)\n", - " print(\"PSNR input\", input_psnr)\n", - " print(\"PSNR prior\", psnr_prior) # Without info from masked pixel\n", - " print(\"PSNR mse\", psnr) # MMSE estimate using the masked pixel\n", - " print(\"-----------------------------------\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "69438c2a", - "metadata": {}, - "outputs": [], - "source": [ - "?prediction.tiledPredict" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d9c27130", - "metadata": {}, - "outputs": [], - "source": [ - "# We display the results for the last test image\n", - "vmi = np.percentile(gt, 0.01)\n", - "vma = np.percentile(gt, 99)\n", - "\n", - "plt.figure(figsize=(15, 15))\n", - "plt.subplot(1, 3, 1)\n", - "plt.title(label=\"Input Image\")\n", - "plt.imshow(im, vmax=vma, vmin=vmi, cmap=\"magma\")\n", - "\n", - "plt.subplot(1, 3, 2)\n", - "plt.title(label=\"Avg. Prior\")\n", - "plt.imshow(means, vmax=vma, vmin=vmi, cmap=\"magma\")\n", - "\n", - "plt.subplot(1, 3, 3)\n", - "plt.title(label=\"PN2V-MMSE estimate\")\n", - "plt.imshow(mse_est, vmax=vma, vmin=vmi, cmap=\"magma\")\n", - "plt.show()\n", - "\n", - "plt.figure(figsize=(15, 15))\n", - "plt.subplot(1, 3, 1)\n", - "plt.title(label=\"Input Image\")\n", - "plt.imshow(im[100:200, 150:250], vmax=vma, vmin=vmi, cmap=\"magma\")\n", - "plt.axhline(y=50, linewidth=3, color=\"white\", alpha=0.5, ls=\"--\")\n", - "\n", - "plt.subplot(1, 3, 2)\n", - "plt.title(label=\"Avg. 
Prior\")\n", - "plt.imshow(means[100:200, 150:250], vmax=vma, vmin=vmi, cmap=\"magma\")\n", - "plt.axhline(y=50, linewidth=3, color=\"white\", alpha=0.5, ls=\"--\")\n", - "\n", - "plt.subplot(1, 3, 3)\n", - "plt.title(label=\"PN2V-MMSE estimate\")\n", - "plt.imshow(mse_est[100:200, 150:250], vmax=vma, vmin=vmi, cmap=\"magma\")\n", - "plt.axhline(y=50, linewidth=3, color=\"white\", alpha=0.5, ls=\"--\")\n", - "\n", - "\n", - "plt.figure(figsize=(15, 5))\n", - "plt.plot(im[150, 150:250], label=\"Input Image\")\n", - "plt.plot(means[150, 150:250], label=\"Avg. Prior\")\n", - "plt.plot(mse_est[150, 150:250], label=\"PN2V-MMSE estimate\")\n", - "plt.plot(gt[150, 150:250], label=\"Pseudo GT by averaging\")\n", - "plt.legend()\n", - "\n", - "plt.show()\n", - "print(\n", - " \"Avg PSNR Prior:\",\n", - " np.mean(np.array(prior_psnrs)),\n", - " \"+-(2SEM)\",\n", - " 2 * np.std(np.array(prior_psnrs)) / np.sqrt(float(len(prior_psnrs))),\n", - ")\n", - "print(\n", - " \"Avg PSNR MMSE:\",\n", - " np.mean(np.array(mmse_psnrs)),\n", - " \"+-(2SEM)\",\n", - " 2 * np.std(np.array(mmse_psnrs)) / np.sqrt(float(len(mmse_psnrs))),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "66930ec5", - "metadata": { - "cell_marker": "\"\"\"" - }, - "source": [ - "---\n", - "---\n", - "

\n", - "**TASK 4.5**\n", - "\n", - "Try PN2V for your own data! You probably don't have calibration data, but with the bootstrapping method you don't need any!\n", - "\n", - "---\n", - "\n", - "**Congratulations!**\n", - "\n", - "You have completed the bonus exercise!\n", - "
" - ] - } - ], - "metadata": { - "jupytext": { - "cell_metadata_filter": "all", - "main_language": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}