diff --git a/README.md b/README.md
index f7a0372..d13da2d 100644
--- a/README.md
+++ b/README.md
@@ -1,34 +1,7 @@
-# Exercise 4: Image translation
-## Setup
+# Exercise 6: Image translation
-Make sure that you are inside of the `image_translation` folder by using the `cd` command to change directories if needed.
+This exercise is split into two parts:
+- Virtual staining with a regression approach using a UNet [part_1](./part_1/)
+- Virtual staining with a generative approach using a GAN [part_2](./part_2/)
-Make sure that you can use mamba to switch environments.
-
-```bash
-mamba init
-```
-
-**Close your shell, and login again.**
-
-Run the setup script to create the environment for this exercise and download the dataset.
-```bash
-sh setup.sh
-```
-Activate your environment
-```bash
-mamba activate 04_image_translation
-```
-
-Launch a jupyter environment
-
-```
-jupyter notebook
-```
-
-...and continue with the instructions in the notebook.
-
-If 04_image_translation is not available as a kernel in jupyter, run
-```
-python -m ipykernel install --user --name=04_image_translation
-```
+Look into the directory for the part you will be working on (i.e., `part_1` or `part_2`) for further installation and 'how to run' instructions.
diff --git a/exercise.ipynb b/exercise.ipynb
deleted file mode 100644
index 1ed75b5..0000000
--- a/exercise.ipynb
+++ /dev/null
@@ -1,1139 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "15d23751",
- "metadata": {
- "cell_marker": "\"\"\""
- },
- "source": [
- "# Image translation\n",
- "---\n",
- "\n",
- "Written by Ziwen Liu and Shalin Mehta, CZ Biohub San Francisco.\n",
- "\n",
- "In this exercise, we will solve an image translation task to predict fluorescence images of nuclei and membrane markers from quantitative phase images of cells. In other words, we will _virtually stain_ the nuclei and membrane visible in the phase image. \n",
- "\n",
- "Here, the source domain is label-free microscopy (material density) and the target domain is fluorescence microscopy (fluorophore density). The goal is to learn a mapping from the source domain to the target domain. We will use a deep convolutional neural network (CNN), specifically, a U-Net model with residual connections to learn the mapping. The preprocessing, training, prediction, evaluation, and deployment steps are unified in a computer vision pipeline for single-cell analysis that we call [VisCy](https://github.com/mehta-lab/VisCy).\n",
- "\n",
- "VisCy evolved from our previous work on virtual staining of cellular components from their density and anisotropy.\n",
- "![](https://iiif.elifesciences.org/lax/55502%2Felife-55502-fig1-v2.tif/full/1500,/0/default.jpg)\n",
- "\n",
- "[Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning\n",
- ". eLife](https://elifesciences.org/articles/55502).\n",
- "\n",
- "VisCy exploits recent advances in the data and metadata formats ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks, [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). "
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b5957320",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 0
- },
- "source": [
- "Today, we will train a 2D image translation model using a 2D U-Net with residual connections. We will use a dataset of 301 fields of view (FOVs) of Human Embryonic Kidney (HEK) cells, each FOV has 3 channels (phase, membrane, and nuclei). The cells were labeled with CRISPR editing. Intrestingly, not all cells during this experiment were labeled due to the stochastic nature of CRISPR editing. In such situations, virtual staining rescues missing labels.\n",
- "![HEK](https://github.com/mehta-lab/VisCy/blob/dlmbl2023/docs/figures/phase_to_nuclei_membrane.svg?raw=true)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "6c62db93",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 0
- },
- "source": [
- "
\n",
- "The exercise is organized in 3 parts.\n",
- "\n",
- "* **Part 1** - Explore the data using tensorboard. Launch the training before lunch.\n",
- "* Lunch break - The model will continue training during lunch.\n",
- "* **Part 2** - Evaluate the training with tensorboard. Train another model.\n",
- "* **Part 3** - Tune the models to improve performance.\n",
- "
"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "30739bb5",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 0
- },
- "source": [
- "📖 As you work through parts 2 and 3, please share the layouts of your models (output of torchview) and their performance with everyone via [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z) 📖.\n",
- "\n",
- "\n",
- "Our guesstimate is that each of the three parts will take ~1.5 hours. A reasonable 2D UNet can be trained in ~20 min on a typical AWS node. \n",
- "We will discuss your observations on google doc after checkpoints 2 and 3.\n",
- "\n",
- "The focus of the exercise is on understanding information content of the data, how to train and evaluate 2D image translation model, and explore some hyperparameters of the model. If you complete this exercise and have time to spare, try the bonus exercise on 3D image translation."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "658e3b31",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 0
- },
- "source": [
- "
\n",
- "Set your python kernel to 04_image_translation\n",
- "
\n",
- "\n",
- "### Task 1.1\n",
- " \n",
- "Look at a couple different fields of view by changing the value in the cell above. See if you notice any missing or inconsistent staining.\n",
- "
"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "7f6f3609",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 1
- },
- "source": [
- "## Explore the effects of augmentation on batch.\n",
- "\n",
- "VisCy builds on top of PyTorch Lightning. PyTorch Lightning is a thin wrapper around PyTorch that allows rapid experimentation. It provides a [DataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) to handle loading and processing of data during training. VisCy provides a child class, `HCSDataModule` to make it intuitve to access data stored in the HCS layout.\n",
- " \n",
- "The dataloader in `HCSDataModule` returns a batch of samples. A `batch` is a list of dictionaries. The length of the list is equal to the batch size. Each dictionary consists of following key-value pairs.\n",
- "- `source`: the input image, a tensor of size 1*1*Y*X\n",
- "- `target`: the target image, a tensor of size 2*1*Y*X\n",
- "- `index` : the tuple of (location of field in HCS layout, time, and z-slice) of the sample."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "1684d72e",
- "metadata": {},
- "source": [
- "
\n",
- "\n",
- "### Task 1.2\n",
- "\n",
- "Setup the data loader and log several batches to tensorboard.\n",
- "\n",
- "Based on the tensorboard images, what are the two channels in the target image?\n",
- "\n",
- "Note: If tensorboard is not showing images, try refreshing and using the \"Images\" tab.\n",
- "
"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "67211280",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define a function to write a batch to tensorboard log.\n",
- "\n",
- "def log_batch_tensorboard(batch, batchno, writer, card_name):\n",
- " \"\"\"\n",
- " Logs a batch of images to TensorBoard.\n",
- "\n",
- " Args:\n",
- " batch (dict): A dictionary containing the batch of images to be logged.\n",
- " writer (SummaryWriter): A TensorBoard SummaryWriter object.\n",
- " card_name (str): The name of the card to be displayed in TensorBoard.\n",
- "\n",
- " Returns:\n",
- " None\n",
- " \"\"\"\n",
- " batch_phase = batch[\"source\"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor.\n",
- " batch_membrane = batch[\"target\"][:, 1, 0, :, :].unsqueeze(\n",
- " 1\n",
- " ) # batch_size x 1 x Y x X tensor.\n",
- " batch_nuclei = batch[\"target\"][:, 0, 0, :, :].unsqueeze(\n",
- " 1\n",
- " ) # batch_size x 1 x Y x X tensor.\n",
- "\n",
- " p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))\n",
- " batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)\n",
- "\n",
- " p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))\n",
- " batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)\n",
- "\n",
- " p1, p99 = np.percentile(batch_phase, (0.1, 99.9))\n",
- " batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)\n",
- "\n",
- " [N, C, H, W] = batch_phase.shape\n",
- " interleaved_images = torch.zeros((3 * N, C, H, W), dtype=batch_phase.dtype)\n",
- " interleaved_images[0::3, :] = batch_phase\n",
- " interleaved_images[1::3, :] = batch_nuclei\n",
- " interleaved_images[2::3, :] = batch_membrane\n",
- "\n",
- " grid = torchvision.utils.make_grid(interleaved_images, nrow=3)\n",
- "\n",
- " # add the grid to tensorboard\n",
- " writer.add_image(card_name, grid, batchno)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "73577e5a",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define a function to visualize a batch on jupyter, in case tensorboard is finicky \n",
- "\n",
- "def log_batch_jupyter(batch):\n",
- " \"\"\"\n",
- " Logs a batch of images on jupyter using ipywidget.\n",
- "\n",
- " Args:\n",
- " batch (dict): A dictionary containing the batch of images to be logged.\n",
- "\n",
- " Returns:\n",
- " None\n",
- " \"\"\"\n",
- " batch_phase = batch[\"source\"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor.\n",
- " batch_size = batch_phase.shape[0]\n",
- " batch_membrane = batch[\"target\"][:, 1, 0, :, :].unsqueeze(\n",
- " 1\n",
- " ) # batch_size x 1 x Y x X tensor.\n",
- " batch_nuclei = batch[\"target\"][:, 0, 0, :, :].unsqueeze(\n",
- " 1\n",
- " ) # batch_size x 1 x Y x X tensor.\n",
- "\n",
- " p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))\n",
- " batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)\n",
- "\n",
- " p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))\n",
- " batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)\n",
- "\n",
- " p1, p99 = np.percentile(batch_phase, (0.1, 99.9))\n",
- " batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)\n",
- "\n",
- " plt.figure()\n",
- " fig, axes = plt.subplots(batch_size, n_channels, figsize=(10, 10))\n",
- " [N, C, H, W] = batch_phase.shape\n",
- " for sample_id in range(batch_size):\n",
- " axes[sample_id, 0].imshow(batch_phase[sample_id,0])\n",
- " axes[sample_id, 1].imshow(batch_nuclei[sample_id,0])\n",
- " axes[sample_id, 2].imshow(batch_membrane[sample_id,0])\n",
- "\n",
- " for i in range(n_channels):\n",
- " axes[sample_id, i].axis(\"off\")\n",
- " axes[sample_id, i].set_title(dataset.channel_names[i])\n",
- " plt.tight_layout()\n",
- " plt.show()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "398f4546",
- "metadata": {
- "lines_to_next_cell": 2
- },
- "outputs": [],
- "source": [
- "\n",
- "# Initialize the data module.\n",
- "\n",
- "BATCH_SIZE = 4\n",
- "# 42 is a perfectly reasonable batch size. After all, it is the answer to the ultimate question of life, the universe and everything.\n",
- "# More seriously, batch size does not have to be a power of 2.\n",
- "# See: https://sebastianraschka.com/blog/2022/batch-size-2.html\n",
- "\n",
- "data_module = HCSDataModule(\n",
- " data_path,\n",
- " source_channel=\"Phase\",\n",
- " target_channel=[\"Nuclei\", \"Membrane\"],\n",
- " z_window_size=1,\n",
- " split_ratio=0.8,\n",
- " batch_size=BATCH_SIZE,\n",
- " num_workers=8,\n",
- " architecture=\"2D\",\n",
- " yx_patch_size=(512, 512), # larger patch size makes it easy to see augmentations.\n",
- " augment=False, # Turn off augmentation for now.\n",
- ")\n",
- "data_module.setup(\"fit\")\n",
- "\n",
- "print(\n",
- " f\"FOVs in training set: {len(data_module.train_dataset)}, FOVs in validation set:{len(data_module.val_dataset)}\"\n",
- ")\n",
- "train_dataloader = data_module.train_dataloader()\n",
- "\n",
- "# Instantiate the tensorboard SummaryWriter, logs the first batch and then iterates through all the batches and logs them to tensorboard.\n",
- "\n",
- "writer = SummaryWriter(log_dir=f\"{log_dir}/view_batch\")\n",
- "# Draw a batch and write to tensorboard.\n",
- "batch = next(iter(train_dataloader))\n",
- "log_batch_tensorboard(batch, 0, writer, \"augmentation/none\")\n",
- "writer.close()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "a4c45450",
- "metadata": {},
- "source": [
- "Visualize directly on Jupyter ☄️, if your tensorboard is causing issues."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "73466d3f",
- "metadata": {},
- "outputs": [],
- "source": [
- "%matplotlib inline\n",
- "log_batch_jupyter(batch)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "19def8d6",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 0
- },
- "source": [
- "## View augmentations using tensorboard."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "97bdcbd8",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Here we turn on data augmentation and rerun setup\n",
- "data_module.augment = True\n",
- "data_module.setup(\"fit\")\n",
- "\n",
- "# get the new data loader with augmentation turned on\n",
- "augmented_train_dataloader = data_module.train_dataloader()\n",
- "\n",
- "# Draw batches and write to tensorboard\n",
- "writer = SummaryWriter(log_dir=f\"{log_dir}/view_batch\")\n",
- "augmented_batch = next(iter(augmented_train_dataloader))\n",
- "log_batch_tensorboard(augmented_batch, 0, writer, \"augmentation/some\")\n",
- "writer.close()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "247bf9c7",
- "metadata": {},
- "source": [
- "Visualize directly on Jupyter ☄️"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "de281de7",
- "metadata": {},
- "outputs": [],
- "source": [
- "log_batch_jupyter(augmented_batch)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "47cc91e4",
- "metadata": {},
- "source": [
- "
\n",
- "\n",
- "### Task 1.3\n",
- "Can you tell what augmentation were applied from looking at the augmented images in Tensorboard?\n",
- "\n",
- "Check your answer using the source code [here](https://github.com/mehta-lab/VisCy/blob/b89f778b34735553cf155904eef134c756708ff2/viscy/light/data.py#L529).\n",
- "
"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5d05336e",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 0
- },
- "source": [
- "## Train a 2D U-Net model to predict nuclei and membrane from phase.\n",
- "\n",
- "### Construct a 2D U-Net\n",
- "See ``viscy.unet.networks.Unet2D.Unet2d`` ([source code](https://github.com/mehta-lab/VisCy/blob/7c5e4c1d68e70163cf514d22c475da8ea7dc3a88/viscy/unet/networks/Unet2D.py#L7)) for configuration details."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "964d5aae",
- "metadata": {
- "lines_to_next_cell": 2
- },
- "outputs": [],
- "source": [
- "# Create a 2D UNet.\n",
- "GPU_ID = 0\n",
- "BATCH_SIZE = 10\n",
- "YX_PATCH_SIZE = (512, 512)\n",
- "\n",
- "\n",
- "# Dictionary that specifies key parameters of the model.\n",
- "phase2fluor_config = {\n",
- " \"architecture\": \"2D\",\n",
- " \"num_filters\": [24, 48, 96, 192, 384],\n",
- " \"in_channels\": 1,\n",
- " \"out_channels\": 2,\n",
- " \"residual\": True,\n",
- " \"dropout\": 0.1, # dropout randomly turns off weights to avoid overfitting of the model to data.\n",
- " \"task\": \"reg\", # reg = regression task.\n",
- "}\n",
- "\n",
- "phase2fluor_model = VSUNet(\n",
- " model_config=phase2fluor_config.copy(),\n",
- " batch_size=BATCH_SIZE,\n",
- " loss_function=torch.nn.functional.l1_loss,\n",
- " schedule=\"WarmupCosine\",\n",
- " log_num_samples=5, # Number of samples from each batch to log to tensorboard.\n",
- " example_input_yx_shape=YX_PATCH_SIZE,\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "eabb4902",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 0
- },
- "source": [
- "### Instantiate data module and trainer, test that we are setup to launch training."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9728a1c0",
- "metadata": {
- "lines_to_next_cell": 2
- },
- "outputs": [],
- "source": [
- "# Setup the data module.\n",
- "phase2fluor_data = HCSDataModule(\n",
- " data_path,\n",
- " source_channel=\"Phase\",\n",
- " target_channel=[\"Nuclei\", \"Membrane\"],\n",
- " z_window_size=1,\n",
- " split_ratio=0.8,\n",
- " batch_size=BATCH_SIZE,\n",
- " num_workers=8,\n",
- " architecture=\"2D\",\n",
- " yx_patch_size=YX_PATCH_SIZE,\n",
- " augment=True,\n",
- ")\n",
- "phase2fluor_data.setup(\"fit\")\n",
- "# fast_dev_run runs a single batch of data through the model to check for errors.\n",
- "trainer = VSTrainer(accelerator=\"gpu\", devices=[GPU_ID], fast_dev_run=True)\n",
- "\n",
- "# trainer class takes the model and the data module as inputs.\n",
- "trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b837c6b2",
- "metadata": {},
- "source": [
- "## View model graph.\n",
- "\n",
- "PyTorch uses dynamic graphs under the hood. The graphs are constructed on the fly. This is in contrast to TensorFlow, where the graph is constructed before the training loop and remains static. In other words, the graph of the network can change with every forward pass. Therefore, we need to supply an input tensor to construct the graph. The input tensor can be a random tensor of the correct shape and type. We can also supply a real image from the dataset. The latter is more useful for debugging."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "31665d0f",
- "metadata": {},
- "source": [
- "
\n",
- "\n",
- "### Task 1.4\n",
- "Run the next cell to generate a graph representation of the model architecture. Can you recognize the UNet structure and skip connections in this graph visualization?\n",
- "
\n",
- "\n",
- "### Task 1.5\n",
- "Start training by running the following cell. Check the new logs on the tensorboard.\n",
- "
"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "26303693",
- "metadata": {},
- "outputs": [],
- "source": [
- "\n",
- "GPU_ID = 0\n",
- "n_samples = len(phase2fluor_data.train_dataset)\n",
- "steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch.\n",
- "n_epochs = 50 # Set this to 50 or the number of epochs you want to train for.\n",
- "\n",
- "trainer = VSTrainer(\n",
- " accelerator=\"gpu\",\n",
- " devices=[GPU_ID],\n",
- " max_epochs=n_epochs,\n",
- " log_every_n_steps=steps_per_epoch // 2,\n",
- " # log losses and image samples 2 times per epoch.\n",
- " logger=TensorBoardLogger(\n",
- " save_dir=log_dir,\n",
- " # lightning trainer transparently saves logs and model checkpoints in this directory.\n",
- " name=\"phase2fluor\",\n",
- " log_graph=True,\n",
- " ),\n",
- " ) \n",
- "# Launch training and check that loss and images are being logged on tensorboard.\n",
- "trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "4260177d",
- "metadata": {
- "cell_marker": "\"\"\""
- },
- "source": [
- "
\n",
- "\n",
- "## Checkpoint 1\n",
- "\n",
- "Now the training has started,\n",
- "we can come back after a while and evaluate the performance!\n",
- "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9f60283b",
+ "metadata": {
+ "cell_marker": "\"\"\"",
+ "lines_to_next_cell": 0,
+ "title": ""
+ },
+ "source": [
+ "# Part 2: Load & Assess trained Pix2PixGAN using tensorboard, discuss performance of the model.\n",
+ "--------------------------------------------------\n",
+ "Learning goals:\n",
+ "- Load a pre-trained Pix2PixHD GAN model for either phase to nuclei.\n",
+ "- Discuss the loss components of Pix2PixHD GAN and how they are used to train the model.\n",
+ "- Evaluate the fit of the model on the train and validation datasets.\n",
+ "\n",
+ "In this part, we will evaluate the performance of the pre-trained model. We will begin by looking qualitatively at the model predictions, then dive into the different loss plots. We will discuss the implications of different hyper-parameter combinations for the performance of the model.\n",
+ "\n",
+ "If you are having issues loading the tensorboard session click \"Launch TensorBoard session\". You should then be able to add the log_dir path below and a tensorboard session shouls then load."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "193d455c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "log_dir = f\"{top_dir}/model_tensorboard/{opt.name}/\"\n",
+ "%reload_ext tensorboard\n",
+ "%tensorboard --logdir $log_dir"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4d98c0dc",
+ "metadata": {
+ "cell_marker": "\"\"\""
+ },
+ "source": [
+ "## Training Results\n",
+ "Please note down your thoughts about the following questions...\n",
+ "\n",
+ "**- What do you notice about the virtual staining predictions? How do they appear compared to the regression-based approach? Can you spot any hallucinations?**
\n",
+ "**- What do you notice about the probabilities of the discriminators? How do the values compare during training compared to validation?**
\n",
+ "**- What do you notice about the feature matching L1 loss?**
\n",
+ "**- What do you notice about the least-square loss?**
\n",
+ " \n",
+ "## Checkpoint 2\n",
+ "Congratulations! You should now have a better understanding the different loss components of Pix2PixHD GAN and how they are used to train the model. You should also have a good understanding of the fit of the model during training on the training and validation datasets.\n",
+ "\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6c62a3d5",
+ "metadata": {
+ "cell_marker": "\"\"\""
+ },
+ "source": [
+ "# Part 3: Evaluate performance of the virtual staining on unseen data.\n",
+ "--------------------------------------------------\n",
+ "## Evaluate the performance of the model.\n",
+ "We now look at the same metrics of performance of the previous model. We typically evaluate the model performance on a held out test data. \n",
+ "\n",
+ "Steps:\n",
+ "- Define our model parameters for the pre-trained model (these are the same parameters as shown in earlier cells but copied here for clarity).\n",
+ "- Load the test data.\n",
+ "\n",
+ "We will first load the test data using the same format as the training and validation data. We will then use the model to sample a virtual nuclei staining soltuion from the phase image. We will then evaluate the performance of the model using the following metrics:\n",
+ "\n",
+ "Pixel-level metrics:\n",
+ "- [Peak-Signal-to-Noise-Ratio (PSNR)](https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio).\n",
+ "- [Structural Similarity Index Measure (SSIM)](https://en.wikipedia.org/wiki/Structural_similarity).\n",
+ "- [Pearson Correlation Coefficient (PCC)](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient).\n",
+ "\n",
+ "Instance-level metrics via [Cellpose masks](https://cellpose.org/):\n",
+ "- [Accuracy](https://en.wikipedia.org/wiki/Accuracy_and_precision#In_binary_classification)\n",
+ "- [Jaccard Index](https://en.wikipedia.org/wiki/Jaccard_index)\n",
+ "- [Dice Score](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient)\n",
+ "- [Mean Average Precision](https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision)"
+ ]
+ },
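+  {
+   "cell_type": "markdown",
+   "id": "a1f2e3d4",
+   "metadata": {
+    "cell_marker": "\"\"\""
+   },
+   "source": [
+    "For reference, these are the standard definitions of the pixel-level metrics (added here for convenience), with $x$ the target image, $\\hat{x}$ the virtual stain, and $L$ the maximum possible intensity (1 after min-max scaling):\n",
+    "\n",
+    "$$\\mathrm{PSNR} = 10\\,\\log_{10}\\frac{L^2}{\\mathrm{MSE}(x,\\hat{x})}, \\qquad \\mathrm{PCC} = \\frac{\\mathrm{cov}(x,\\hat{x})}{\\sigma_x\\,\\sigma_{\\hat{x}}},$$\n",
+    "\n",
+    "$$\\mathrm{SSIM}(x,\\hat{x}) = \\frac{(2\\mu_x\\mu_{\\hat{x}}+c_1)(2\\sigma_{x\\hat{x}}+c_2)}{(\\mu_x^2+\\mu_{\\hat{x}}^2+c_1)(\\sigma_x^2+\\sigma_{\\hat{x}}^2+c_2)}$$"
+   ]
+  },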
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d38b033f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "opt = TestOptions().parse(save=False)\n",
+ "\n",
+ "# Define the parameters for the dataset.\n",
+ "opt.dataroot = output_image_folder\n",
+ "opt.data_type = 16 # Data type of the images.\n",
+ "opt.loadSize = 512 # Size of the loaded phase image.\n",
+ "opt.input_nc = 1 # Number of input channels.\n",
+ "opt.output_nc = 1 # Number of output channels.\n",
+ "opt.target = \"nuclei\" # \"nuclei\" or \"cyto\" depending on your choice of target for virtual stain.\n",
+ "opt.resize_or_crop = \"none\" # Scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop|none].\n",
+ "opt.batchSize = 1 # Batch size for training\n",
+ "\n",
+ "# Define the model parameters for the pre-trained model.\n",
+ "\n",
+ "# Define the parameters for the Generator.\n",
+ "opt.ngf = 64 # Number of filters in the generator.\n",
+ "opt.n_downsample_global = 4 # Number of downsampling layers in the generator.\n",
+ "opt.n_blocks_global = 9 # Number of residual blocks in the generator.\n",
+ "opt.n_blocks_local = 3 # Number of residual blocks in the generator.\n",
+ "opt.n_local_enhancers = 1 # Number of local enhancers in the generator.\n",
+ "\n",
+ "# Define the parameters for the Discriminators.\n",
+ "opt.num_D = 3 # Number of discriminators.\n",
+ "opt.n_layers_D = 3 # Number of layers in the discriminator.\n",
+ "opt.ndf = 32 # Number of filters in the discriminator.\n",
+ "\n",
+ "# Define general training parameters.\n",
+ "opt.gpu_ids= [0] # GPU ids to use.\n",
+ "opt.norm = \"instance\" # Normalization layer in the generator.\n",
+ "opt.use_dropout = \"\" # Use dropout in the generator (fixed at 0.2).\n",
+ "opt.batchSize = 8 # Batch size.\n",
+ "\n",
+ "# Define loss functions.\n",
+ "opt.no_vgg_loss = \"\" # Turn off VGG loss\n",
+ "opt.no_ganFeat_loss = \"\" # Turn off feature matching loss\n",
+ "opt.no_lsgan = \"\" # Turn off least square loss\n",
+ "\n",
+ "# Additional Inference parameters\n",
+ "opt.name = f\"dlmbl_vsnuclei\"\n",
+ "opt.how_many = 112 # Number of images to generate.\n",
+ "opt.checkpoints_dir = f\"{top_dir}/model_weights/\" # Path to the model checkpoints.\n",
+ "opt.results_dir = f\"{top_dir}/GAN_code/GANs_MI2I/pre_trained/{opt.name}/inference_results/\" # Path to store the results.\n",
+ "opt.which_epoch = \"latest\" # or specify the epoch number \"40\"\n",
+ "opt.phase = \"test\"\n",
+ "\n",
+ "opt.nThreads = 1 # test code only supports nThreads = 1\n",
+ "opt.batchSize = 1 # test code only supports batchSize = 1\n",
+ "opt.serial_batches = True # no shuffle\n",
+ "opt.no_flip = True # no flip\n",
+ "Path(opt.results_dir).mkdir(parents=True, exist_ok=True)\n",
+ "\n",
+ "# Load the test data.\n",
+ "test_data_loader = CreateDataLoader(opt)\n",
+ "test_dataset = test_data_loader.load_data()\n",
+ "visualizer = Visualizer(opt)\n",
+ "print(f\"Total Test Images = {len(test_data_loader)}\")\n",
+ "# Load pre-trained model\n",
+ "model = create_model(opt)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "dde67288",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Generate & save predictions in the results directory.\n",
+ "inference_model(test_dataset, opt, model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "95db1b7e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Gather results for evaluation\n",
+ "virtual_stain_paths = sorted([i for i in Path(opt.results_dir).glob(\"**/*.tiff\")])\n",
+ "target_stain_paths = sorted([i for i in Path(f\"{output_image_folder}/{translation_task}/test/\").glob(\"**/*.tiff\")])\n",
+ "phase_paths = sorted([i for i in Path(f\"{output_image_folder}/input/test/\").glob(\"**/*.tiff\")])\n",
+ "assert (len(virtual_stain_paths) == len(target_stain_paths) == len(phase_paths)\n",
+ "), f\"Number of images do not match. {len(virtual_stain_paths)},{len(target_stain_paths)} {len(phase_paths)} \"\n",
+ "\n",
+ "# Create arrays to store the images.\n",
+ "virtual_stains = np.zeros((len(virtual_stain_paths), 512, 512))\n",
+ "target_stains = virtual_stains.copy()\n",
+ "phase_images = virtual_stains.copy()\n",
+ "# Load the images and store them in the arrays.\n",
+ "for index, (v_path, t_path, p_path) in tqdm(\n",
+ " enumerate(zip(virtual_stain_paths, target_stain_paths, phase_paths))\n",
+ "):\n",
+ " virtual_stain = imread(v_path)\n",
+ " phase_image = imread(p_path)\n",
+ " target_stain = imread(t_path)\n",
+ " # Append the images to the arrays.\n",
+ " phase_images[index] = phase_image\n",
+ " target_stains[index] = target_stain\n",
+ " virtual_stains[index] = virtual_stain"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d168855d",
+ "metadata": {
+ "cell_marker": "\"\"\"",
+ "lines_to_next_cell": 0,
+ "tags": []
+ },
+ "source": [
+ "
\n",
+ "\n",
+ "### Task 3.1 Visualise the results of the model on the test set.\n",
+ "\n",
+ "Create a matplotlib plot that visalises random samples of the phase images, target stains, and virtual stains.\n",
+ "If you can incorporate the crop function below to zoom in on the images that would be great!\n",
+ "
\n",
+ "\n",
+ "### Task 3.2 Compute pixel-level metrics\n",
+ "\n",
+ "Compute the pixel-level metrics for the virtual stains and target stains.\n",
+ "\n",
+ "The following code will compute the following:\n",
+ "- the pixel-based metrics (Pearson correlation, SSIM, PSNR) for the virtual stains and target stains.\n",
+ "\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7ea32c4c",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Define the function to perform minmax normalization which is required for the pixel-level metrics.\n",
+ "def min_max_scale(input):\n",
+ " return (input - np.min(input)) / (np.max(input) - np.min(input))\n",
+ "\n",
+ "# Create a dataframe to store the pixel-level metrics.\n",
+ "test_pixel_metrics = pd.DataFrame(\n",
+ " columns=[\"model\", \"fov\",\"pearson_nuc\", \"ssim_nuc\", \"psnr_nuc\"]\n",
+ ")\n",
+ "\n",
+ "# Compute the pixel-level metrics.\n",
+ "for i, (target_stain, predicted_stain) in tqdm(enumerate(zip(target_stains, virtual_stains))):\n",
+ " fov = str(virtual_stain_paths[i]).split(\"/\")[-1].split(\".\")[0]\n",
+ " minmax_norm_target = min_max_scale(target_stain)\n",
+ " minmax_norm_predicted = min_max_scale(predicted_stain)\n",
+ " \n",
+ " # Compute SSIM\n",
+ " ssim_nuc = metrics.structural_similarity(\n",
+ " minmax_norm_target, minmax_norm_predicted, data_range=1\n",
+ " )\n",
+ " # Compute Pearson correlation\n",
+ " pearson_nuc = np.corrcoef(\n",
+ " minmax_norm_target.flatten(), minmax_norm_predicted.flatten()\n",
+ " )[0, 1]\n",
+ " # Compute PSNR\n",
+ " psnr_nuc = metrics.peak_signal_noise_ratio(\n",
+ " minmax_norm_target, minmax_norm_predicted, data_range=1\n",
+ " )\n",
+ " \n",
+ " test_pixel_metrics.loc[len(test_pixel_metrics)] = {\n",
+ " \"model\": \"pix2pixHD\",\n",
+ " \"fov\":fov,\n",
+ " \"pearson_nuc\": pearson_nuc,\n",
+ " \"ssim_nuc\": ssim_nuc,\n",
+ " \"psnr_nuc\": psnr_nuc, \n",
+ " }\n",
+ " \n",
+ "test_pixel_metrics.boxplot(\n",
+ " column=[\"pearson_nuc\", \"ssim_nuc\"],\n",
+ " rot=30,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e09b3843",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "test_pixel_metrics.boxplot(\n",
+ " column=[\"psnr_nuc\"],\n",
+ " rot=30,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "af1a13e0",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "test_pixel_metrics.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "83bbcee9",
+ "metadata": {
+ "cell_marker": "\"\"\"",
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "## Inference Pixel-level Results\n",
+ "Please note down your thoughts about the following questions...\n",
+ "\n",
+ "- What do these metrics tells us about the performance of the model?\n",
+ "- How do the pixel-level metrics compare to the regression-based approach?\n",
+ "- Could these metrics be skewed by the presence of hallucinations or background pilxels in the virtual stains?"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "74de240b",
+ "metadata": {
+ "cell_marker": "\"\"\"",
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "
\n",
+ "\n",
+ "### Task 3.3 Compute instance-level metrics\n",
+ "\n",
+ "- Compute the instance-level metrics for the virtual stains and target stains.\n",
+ "- Instance metrics include the accuracy (average correct predictions with 0.5 threshold), jaccard index (intersection over union (IoU)) dice score (2x intersection over union), mean average precision, mean average precision at 50% IoU, mean average precision at 75% IoU, and mean average recall at 100% IoU.\n",
+ "\n",
+ "\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "25bfb934",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Use the same function as previous part to extract the nuclei masks from pre-trained cellpose model.\n",
+ "def cellpose_segmentation(prediction:ArrayLike,target:ArrayLike)->Tuple[torch.ShortTensor]:\n",
+ " # NOTE these are hardcoded for this notebook and A549 dataset\n",
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ " cp_nuc_kwargs = {\n",
+ " \"diameter\": 65,\n",
+ " \"cellprob_threshold\": 0.0, \n",
+ " }\n",
+ " cellpose_model = models.CellposeModel(\n",
+ " gpu=True, model_type='nuclei', device=torch.device(device)\n",
+ " )\n",
+ " pred_label, _, _ = cellpose_model.eval(prediction, **cp_nuc_kwargs)\n",
+ " target_label, _, _ = cellpose_model.eval(target, **cp_nuc_kwargs)\n",
+ "\n",
+ " pred_label = pred_label.astype(np.int32)\n",
+ " target_label = target_label.astype(np.int32)\n",
+ " pred_label = torch.ShortTensor(pred_label)\n",
+ " target_label = torch.ShortTensor(target_label)\n",
+ "\n",
+ " return (pred_label,target_label)\n",
+ "\n",
+ "# Define dataframe to store the segmentation metrics.\n",
+ "test_segmentation_metrics= pd.DataFrame(\n",
+ " columns=[\"model\", \"fov\",\"masks_per_fov\",\"accuracy\",\"dice\",\"jaccard\",\"mAP\",\"mAP_50\",\"mAP_75\",\"mAR_100\"]\n",
+ ")\n",
+ "# Define tuple to store the segmentation results. Each value in the tuple is a dictionary containing the model name, fov, predicted label, predicted stain, target label, and target stain.\n",
+ "segmentation_results = ()\n",
+ "\n",
+ "for i, (target_stain, predicted_stain) in tqdm(enumerate(zip(target_stains, virtual_stains))):\n",
+ " fov = str(virtual_stain_paths)[i].spilt(\"/\")[-1].split(\".\")[0]\n",
+ " minmax_norm_target = min_max_scale(target_stain)\n",
+ " minmax_norm_predicted = min_max_scale(predicted_stain)\n",
+ " # Compute the segmentation masks.\n",
+ " pred_label, target_label = cellpose_segmentation(minmax_norm_predicted, minmax_norm_target)\n",
+ " # Binary labels\n",
+ " pred_label_binary = pred_label > 0\n",
+ " target_label_binary = target_label > 0\n",
+ "\n",
+ " # Use Coco metrics to get mean average precision\n",
+ " coco_metrics = mean_average_precision(pred_label, target_label)\n",
+ " # Find unique number of labels\n",
+ " num_masks_fov = len(np.unique(pred_label))\n",
+ " # Find unique number of labels\n",
+ " num_masks_fov = len(np.unique(pred_label))\n",
+ " # Compute the segmentation metrics.\n",
+ " test_segmentation_metrics.loc[len(test_segmentation_metrics)] = {\n",
+ " \"model\": \"pix2pixHD\",\n",
+ " \"fov\":fov,\n",
+ " \"masks_per_fov\": num_masks_fov,\n",
+ " \"accuracy\": accuracy(pred_label_binary, target_label_binary, task=\"binary\").item(),\n",
+ " \"dice\": dice(pred_label_binary, target_label_binary).item(),\n",
+ " \"jaccard\": jaccard_index(pred_label_binary, target_label_binary, task=\"binary\").item(),\n",
+ " \"mAP\":coco_metrics[\"map\"].item(),\n",
+ " \"mAP_50\":coco_metrics[\"map_50\"].item(),\n",
+ " \"mAP_75\":coco_metrics[\"map_75\"].item(),\n",
+ " \"mAR_100\":coco_metrics[\"mar_100\"].item()\n",
+ " }\n",
+ " # Store the segmentation results.\n",
+ " segmentation_result = {\n",
+ " \"model\": \"pix2pixHD\",\n",
+ " \"fov\":fov,\n",
+ " \"phase_image\": phase_images[i],\n",
+ " \"pred_label\": pred_label,\n",
+ " \"pred_stain\": predicted_stain,\n",
+ " \"target_label\": target_label,\n",
+ " \"target_stain\": target_stain,\n",
+ " }\n",
+ " segmentation_results += (segmentation_result,)\n",
+ "\n",
+ "test_segmentation_metrics.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3c19de9e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define function to visualize the segmentation results.\n",
+ "def visualise_results_and_masks(segmentation_results: Tuple[dict], segmentation_metrics: pd.DataFrame, rows: int = 5, crop_size: int = None, crop_type: str = 'center'):\n",
+ "\n",
+ " # Sample a subset of the segmentation results.\n",
+ " sample_indices = np.random.choice(len(phase_images),rows)\n",
+ " print(sample_indices)\n",
+ " segmentation_metrics = segmentation_metrics.iloc[sample_indices,:]\n",
+ " segmentation_results = [segmentation_results[i] for i in sample_indices]\n",
+ " # Define the figure and axes.\n",
+ " fig, axes = plt.subplots(rows, 5, figsize=(rows*3, 15))\n",
+ "\n",
+ " # Visualize the segmentation results.\n",
+ " for i in range(len((segmentation_results))):\n",
+ " segmentation_metric = segmentation_metrics.iloc[i]\n",
+ " result = segmentation_results[i]\n",
+ " phase_image = result[\"phase_image\"]\n",
+ " target_stain = result[\"target_stain\"]\n",
+ " target_label = result[\"target_label\"]\n",
+ " pred_stain = result[\"pred_stain\"]\n",
+ " pred_label = result[\"pred_label\"]\n",
+ " # Crop the images if required. Zoom into instances\n",
+ " if crop_size is not None:\n",
+ " phase_image = crop(phase_image, crop_size, crop_type)\n",
+ " target_stain = crop(target_stain, crop_size, crop_type)\n",
+ " target_label = crop(target_label, crop_size, crop_type)\n",
+ " pred_stain = crop(pred_stain, crop_size, crop_type)\n",
+ " pred_label = crop(pred_label, crop_size, crop_type)\n",
+ " \n",
+ " axes[i, 0].imshow(phase_image, cmap=\"gray\")\n",
+ " axes[i, 0].set_title(\"Phase\")\n",
+ " axes[i, 1].imshow(\n",
+ " target_stain,\n",
+ " cmap=\"gray\",\n",
+ " vmin=np.percentile(target_stain, 1),\n",
+ " vmax=np.percentile(target_stain, 99),\n",
+ " )\n",
+ " axes[i, 1].set_title(\"Target Fluorescence\")\n",
+ " axes[i, 2].imshow(pred_stain, cmap=\"gray\")\n",
+ " axes[i, 2].set_title(\"Virtual Stain\")\n",
+ " axes[i, 3].imshow(target_label, cmap=\"inferno\")\n",
+ " axes[i, 3].set_title(\"Target Fluorescence Mask\")\n",
+ " axes[i, 4].imshow(pred_label, cmap=\"inferno\")\n",
+ " # Add Metric values to the title\n",
+ " axes[i, 4].set_title(f\"Virtual Stain Mask\\nAcc:{segmentation_metric['accuracy']:.2f} Dice:{segmentation_metric['dice']:.2f}\\nJaccard:{segmentation_metric['jaccard']:.2f} MAP:{segmentation_metric['mAP']:.2f}\")\n",
+ " # Turn off the axes.\n",
+ " for ax in axes.flatten():\n",
+ " ax.axis(\"off\")\n",
+ "\n",
+ " plt.tight_layout()\n",
+ " plt.show()\n",
+ " \n",
+ "visualise_results_and_masks(segmentation_results,test_segmentation_metrics, crop_size=256, crop_type='center')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6f11e65a",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "488b6367",
+ "metadata": {
+ "cell_marker": "\"\"\"",
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "## Inference Instance-level Results\n",
+ "Please note down your thoughts about the following questions...\n",
+ "\n",
+ "- What do these metrics tells us about the performance of the model?\n",
+ "- How does the performance compare to when looking at pixel-level metrics?"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9395a312",
+ "metadata": {
+ "cell_marker": "\"\"\""
+ },
+ "source": [
+ "
\n",
+ " \n",
+ "## Checkpoint 3\n",
+ "\n",
+ "Congratulations! You have generated predictions from a pre-trained model and evaluated the performance of the model on unseen data. You have computed pixel-level metrics and instance-level metrics to evaluate the performance of the model. You may have also began training your own Pix2PixHD GAN models with alternative hyperparameters.\n",
+ "\n",
+ "
\n",
+ " \n",
+ "## Checkpoint 4\n",
+ "\n",
+ "Congratulations! You should now have a better understanding of the difference in performance for image translation when approaching the problem using a regression vs. generative modelling approaches!\n",
+ "\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ec2331fd",
+ "metadata": {
+ "cell_marker": "\"\"\"",
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "# Part 5: BONUS: Sample different virtual staining solutions from the GAN using MC-Dropout and explore the uncertainty in the virtual stain predictions.\n",
+ "--------------------------------------------------\n",
+ "Steps:\n",
+ "- Load the pre-trained model.\n",
+ "- Generate multiple predictions for the same input image.\n",
+ "- Compute the pixel-wise variance across the predictions.\n",
+ "- Visualise the pixel-wise variance to explore the uncertainty in the virtual stain predictions."
+ ]
+ },
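+  {
+   "cell_type": "markdown",
+   "id": "b2c3d4e5",
+   "metadata": {
+    "cell_marker": "\"\"\""
+   },
+   "source": [
+    "For intuition, the next cell is a generic MC-Dropout sketch (not the pix2pixHD implementation, which is driven by `opt.dropout_variation_inf` below): dropout layers are kept stochastic at inference time while the rest of the network stays in eval mode, so repeated forward passes yield different virtual staining solutions. The `enable_mc_dropout` helper and the commented usage are illustrative assumptions only."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c3d4e5f6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch.nn as nn\n",
+    "\n",
+    "def enable_mc_dropout(network: nn.Module) -> None:\n",
+    "    # Put the network in eval mode but keep dropout layers stochastic.\n",
+    "    network.eval()\n",
+    "    for module in network.modules():\n",
+    "        if isinstance(module, nn.Dropout):\n",
+    "            module.train()\n",
+    "\n",
+    "# Hypothetical usage: draw several stochastic predictions for the same phase input.\n",
+    "# enable_mc_dropout(generator)\n",
+    "# draws = torch.stack([generator(phase_input) for _ in range(100)])"
+   ]
+  },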
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1ea346ff",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Use the same model and dataloaders as before.\n",
+ "# Load the test data.\n",
+ "test_data_loader = CreateDataLoader(opt)\n",
+ "test_dataset = test_data_loader.load_data()\n",
+ "visualizer = Visualizer(opt)\n",
+ "\n",
+ "# Load pre-trained model\n",
+ "opt.variational_inf_runs = 100 # Number of samples per phase input\n",
+ "opt.variation_inf_path = f\"./GAN_code/GANs_MI2I/pre_trained/{opt.name}/samples/\" # Path to store the samples.\n",
+ "opt.results_dir = f\"{top_dir}/GAN_code/GANs_MI2I/pre_trained/dlmbl_vsnuclei/sampling_results\"\n",
+ "opt.dropout_variation_inf = True # Use dropout during inference.\n",
+ "model = create_model(opt)\n",
+ "# Generate & save predictions in the variation_inf_path directory.\n",
+ "sampling(test_dataset, opt, model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d359ea46",
+ "metadata": {
+ "lines_to_next_cell": 1
+ },
+ "outputs": [],
+ "source": [
+ "# Visualise Samples \n",
+ "samples = sorted([i for i in Path(f\"{opt.results_dir}\").glob(\"**/*.tif*\")])\n",
+ "assert len(samples) == 5\n",
+ "# Create arrays to store the images.\n",
+ "sample_images = np.zeros((5,100, 512, 512)) # (samples, images, height, width)\n",
+ "# Load the images and store them in the arrays.\n",
+ "for index, sample_path in tqdm(enumerate(samples)):\n",
+ " sample_image = imread(sample_path)\n",
+ " # Append the images to the arrays.\n",
+ " sample_images[index] = sample_image"
+ ]
+ },
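+  {
+   "cell_type": "markdown",
+   "id": "d4e5f6a7",
+   "metadata": {
+    "cell_marker": "\"\"\""
+   },
+   "source": [
+    "A minimal sketch of the variance step, assuming `sample_images` has shape `(samples, images, height, width)` as loaded above: compute the pixel-wise mean and variance across the MC-Dropout draws for one field of view and display them side by side."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e5f6a7b8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Pixel-wise mean and variance across the MC-Dropout draws for the first field of view.\n",
+    "mean_image = sample_images[0].mean(axis=0)\n",
+    "variance_image = sample_images[0].var(axis=0)\n",
+    "\n",
+    "fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
+    "axes[0].imshow(mean_image, cmap=\"gray\")\n",
+    "axes[0].set_title(\"Pixel-wise mean\")\n",
+    "axes[1].imshow(variance_image, cmap=\"inferno\")\n",
+    "axes[1].set_title(\"Pixel-wise variance\")\n",
+    "for ax in axes:\n",
+    "    ax.axis(\"off\")\n",
+    "plt.tight_layout()\n",
+    "plt.show()"
+   ]
+  },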
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "26c430fc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create a matplotlib plot with animation through images.\n",
+ "def animate_images(images):\n",
+ " # Expecting images to have shape (frames, height, width)\n",
+ " fig, ax = plt.subplots()\n",
+ " ax.axis('off')\n",
+ " \n",
+ " # Make sure images are in (frames, height, width) order\n",
+ " images = images.transpose(0, 2, 1) if images.shape[1] == images.shape[2] else images\n",
+ " \n",
+ " imgs = []\n",
+ " for i in range(min(100, len(images))): # Ensure you don't exceed the number of frames\n",
+ " im = ax.imshow(images[i], animated=True)\n",
+ " imgs.append([im])\n",
+ "\n",
+ " ani = animation.ArtistAnimation(fig, imgs, interval=100, blit=False, repeat_delay=1000)\n",
+ " \n",
+ " # Display the animation\n",
+ " # plt.close(fig)\n",
+ " display(HTML(ani.to_jshtml()))\n",
+ "\n",
+ "# Example call with sample_images[0]\n",
+ "animate_images(sample_images[0])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c2e73539",
+ "metadata": {
+ "cell_marker": "\"\"\""
+ },
+ "source": [
+ "
\n",
+ " \n",
+ "## Checkpoint 5\n",
+ "\n",
+ "Congratulations! This is the end of the conditional generative modelling approach to image translation notebook. You have trained and examined the loss components of Pix2PixHD GAN. You have compared the results of a regression-based approach vs. generative modelling approach and explored the variability in virtual staining solutions. I hope you have enjoyed learning experience!\n",
+ "
"
+ ]
+ }
+ ],
+ "metadata": {
+ "jupytext": {
+ "cell_metadata_filter": "all",
+ "main_language": "python"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.19"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/part_2/solution.py b/part_2/solution.py
new file mode 100644
index 0000000..498b9f6
--- /dev/null
+++ b/part_2/solution.py
@@ -0,0 +1,927 @@
+# %% [markdown]
+"""
+# A Generative Modelling Approach to Image translation
+Written by [Samuel Tonks](https://github.com/Tonks684), Krull Lab, University of Birmingham, UK, with many inputs and bugfixes from [Eduardo Hirata-Miyasaki](https://github.com/edyoshikun), [Ziwen Liu](https://github.com/ziw-liu) and [Shalin Mehta](https://github.com/mattersoflight) of CZ Biohub San Francisco.
+"""
+
+# %% [markdown]
+"""
+## Overview
+
+In part 2 of the image_translation exercise, we will predict fluorescence images of nuclei markers only from quantitative phase images of cells, using a specific type of generative model called a Conditional Generative Adversarial Network (cGAN). In contrast to a regression-based approach, cGANs learn to map from the phase contrast domain to a distribution of virtual staining solutions. In this work we will utilise the [Pix2PixHD GAN](https://arxiv.org/abs/1711.11585) used in our recent [virtual staining works](https://ieeexplore.ieee.org/abstract/document/10230501?casa_token=NEyrUDqvFfIAAAAA:tklGisf9BEKWVjoZ6pgryKvLbF6JyurOu5Jrgoia1QQLpAMdCSlP9gMa02f3w37PvVjdiWCvFhA). For more details on the architecture and loss components of cGANs and the Pix2PixHD GAN, please see the README.
+
+During this exercise we will assess the different loss components of a pre-trained Pix2PixHD GAN for the virtual nuclei staining task. We will then evaluate the performance of the model on unseen data using the same pixel-level and instance-level metrics as in the previous section. We will compare the performance of the Pix2PixHD GAN with the regression-based model VisCy. Finally, as a bonus, we will explore the variability and uncertainty in the virtual stain predictions using [MC-Dropout](https://arxiv.org/abs/1506.02142).
+
+## References
+- [Wang, T. et al. (2018) High-resolution image synthesis and semantic manipulation with conditional GANs](https://arxiv.org/abs/1711.11585)
+- [Tonks, S. et al. (2023) Evaluating virtual staining for high-throughput screening](https://ieeexplore.ieee.org/abstract/document/10230501?casa_token=NEyrUDqvFfIAAAAA:tklGisf9BEKWVjoZ6pgryKvLbF6JyurOu5Jrgoia1QQLpAMdCSlP9gMa02f3w37PvVjdiWCvFhA)
+- [Gal, Y. et al. (2016) Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning](https://arxiv.org/abs/1506.02142)
+
+"""
+# %% [markdown]
+"""
+## Goals
+
+This part of the exercise is organized in 5 parts.
+
+As you have already explored the data in the previous parts, we will focus on training and evaluating Pix2PixHD GAN. The parts are as follows:
+
+* **Part 1** - Define dataloaders & walk through steps to train a Pix2PixHD GAN.
+* **Part 2** - Load and assess a pre-trained Pix2PixHD GAN using tensorboard, discuss the different loss components and how new hyper-parameter configurations could impact performance.
+* **Part 3** - Evaluate performance of the pre-trained Pix2PixHD GAN using pixel-level and instance-level metrics.
+* **Part 4** - Compare the performance of VisCy (regression-based) with the Pix2PixHD GAN (generative modelling approach)
+* **Part 5** - *BONUS*: Sample different virtual staining solutions from the Pix2PixHD GAN using [MC-Dropout](https://arxiv.org/abs/1506.02142) and explore the variability and subsequent uncertainty in the virtual stain predictions.
+
+"""
+
+
+# %% [markdown]
+"""
+
+Set your python kernel to 06_image_translation
+
+
+"""
+# %% [markdown]
+"""
+If you have issues getting the kernel to load, please follow these steps:
+1. Ctrl+Shift+P to open Command Palette. Type Python: Select Interpreter and select 06_image_translation.
+2. Register the environment as a Kernel using the below line of code.
+3. Reload VS Code via Ctrl+Shift+P, then select Reload Window.
+"""
+#%%
+# !python -m ipykernel install --user --name 06_image_translation --display-name "Python 06_image_translation"
+
+# %% [markdown]
+"""
+# Part 1: Define dataloaders & walk through steps to train a Pix2PixHD GAN.
+---------
+The focus of this part of the exercise is on understanding a generative modelling approach to image translation, how to train and evaluate a cGAN, and how to explore some of the cGAN's hyperparameters.
+
+Learning goals:
+
+- Load dataset and configure dataloader.
+- Configure Pix2PixHD GAN to train for translating from phase to nuclei.
+
+
+Before we start, please set the first section of the parent_dir to your personal path.
+"""
+# %%
+# TO DO: Change the path to the directory where the data and code are stored.
+import os
+import sys
+parent_dir = os.path.abspath("ADD_HERE/data/06_image_translation/part2/GAN_code/GANs_MI2I/")
+sys.path.append(parent_dir)
+
+# %%
+from pathlib import Path
+import torch
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from skimage import metrics
+from tifffile import imread, imsave
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+from IPython.display import HTML
+
+from cellpose import models
+from typing import List, Tuple
+from numpy.typing import ArrayLike
+
+import warnings
+warnings.filterwarnings('ignore')
+
+# Import all the necessary hyperparameters and configurations for training.
+from pix2pixHD.options.train_options import TrainOptions
+from pix2pixHD.options.test_options import TestOptions
+
+# Import Pytorch dataloader and transforms.
+from pix2pixHD.data.data_loader_dlmbl import CreateDataLoader
+
+# Import the model architecture.
+from pix2pixHD.models import create_model
+
+# Import helper functions for visualization and processing.
+from pix2pixHD.util.visualizer import Visualizer
+from pix2pixHD.util import util
+
+# Import train script.
+from pix2pixHD.train_dlmbl import train as train_model
+from pix2pixHD.test_dlmbl import inference as inference_model
+from pix2pixHD.test_dlmbl import sampling
+
+# TensorBoard SummaryWriter from PyTorch.
+from torch.utils.tensorboard import SummaryWriter
+
+# Import the same evaluation metrics as in the previous section.
+from viscy.evaluation.evaluation_metrics import mean_average_precision
+from torchmetrics.functional import accuracy, dice, jaccard_index
+
+# Initialize the default options and parse the arguments.
+opt = TrainOptions().parse()
+# Set the seed for reproducibility.
+util.set_seed(42)
+# Set the experiment folder name.
+translation_task = "nuclei" # or "cyto" depending on your choice of target for virtual stain.
+opt.name = "dlmbl_vsnuclei"
+# Path to store all the logs.
+top_dir = Path("mnt/efs/dlmbl")
+top_dir = top_dir/Path(f"data/06_image_translation/part2")
+opt.checkpoints_dir = top_dir/"GAN_code/GANs_MI2I/new_training_runs/"
+Path(f'{opt.checkpoints_dir}/{opt.name}').mkdir(parents=True, exist_ok=True)
+output_image_folder = top_dir/"tiff_files/"
+# Initialize the tensorboard writer.
+writer = SummaryWriter(log_dir=f'{opt.checkpoints_dir}/{opt.name}')
+# %% [markdown]
+"""
+## 1.1 Load Dataset & Configure Dataloaders.
+Having already downloaded and split our training, validation and test sets, we now need to load the data into the model. We will use the PyTorch DataLoader class, an iterator that provides a consistent way to load data in batches, together with the CreateDataLoader class, which loads the data in the correct format for the Pix2PixHD GAN. A generic illustration of this batching behaviour is sketched in the next cell.
+"""
+# %%
+# Initialize the Dataset and Dataloaders.
+
+## Define Dataset & Dataloader options.
+opt.dataroot = output_image_folder
+opt.data_type = 16 # Data type of the images.
+opt.loadSize = 512 # Size of the loaded phase image.
+opt.input_nc = 1 # Number of input channels.
+opt.output_nc = 1 # Number of output channels.
+opt.resize_or_crop = "none" # Scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop|none].
+opt.target = "nuclei" # or "cyto" depending on your choice of target for virtual stain.
+opt.batchSize = 16
+# Load Training Set for input into model
+opt.isTrain = True
+train_dataloader = CreateDataLoader(opt)
+dataset_train = train_dataloader.load_data()
+print(f"Total Training Images = {len(train_dataloader)}")
+
+# Load Val Set
+opt.phase = "val"
+val_dataloader = CreateDataLoader(opt)
+dataset_val = val_dataloader.load_data()
+print(f"Total Validation Images = {len(val_dataloader)}")
+opt.phase= "train"
+
+# Plot a sample image from the training set.
+# %% [markdown]
+"""
+## Configure Pix2PixHD GAN and train to predict nuclei from phase.
+Having loaded the data into the model we can now train the Pix2PixHD GAN to predict nuclei from phase. We will use the following hyperparameters to train the model:
+
+"""
+# %%
+# Define the parameters for the Generator.
+opt.ngf = 64 # Number of filters in the generator.
+opt.n_downsample_global = 4 # Number of downsampling layers in the generator.
+opt.n_blocks_global = 9 # Number of residual blocks in the generator.
+opt.n_blocks_local = 3 # Number of residual blocks in the local enhancer network.
+opt.n_local_enhancers = 1 # Number of local enhancers in the generator.
+
+# Define the parameters for the Discriminators.
+opt.num_D = 3 # Number of discriminators.
+opt.n_layers_D = 3 # Number of layers in the discriminator.
+opt.ndf = 32 # Number of filters in the discriminator.
+
+# Define general training parameters.
+opt.gpu_ids = [0] # GPU ids to use.
+opt.norm = "instance" # Normalization layer in the generator.
+opt.use_dropout = "" # Use dropout in the generator (fixed at 0.2).
+# Create a visualizer to perform image processing and visualization
+visualizer = Visualizer(opt)
+
+# Here we will first start training a model from scratch; however, we can continue training from a previously trained model by setting the following parameters.
+opt.continue_train = False
+if opt.continue_train:
+ iter_path = os.path.join(opt.checkpoints_dir, opt.name, "iter.txt")
+ try:
+ start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=",", dtype=int)
+ except:
+ start_epoch, epoch_iter = 1, 0
+ print("Resuming from epoch %d at iteration %d" % (start_epoch, epoch_iter))
+else:
+ start_epoch, epoch_iter = 1, 0
+
+print('------------ Options -------------')
+for k, v in sorted(vars(opt).items()):
+ print('%s: %s' % (str(k), str(v)))
+print('-------------- End ----------------')
+
+# Set the number of epochs to 2 for demonstration purposes.
+opt.n_epochs = 2
+# Initialize the model
+phase2nuclei_model = create_model(opt)
+# Define Optimizers for G and D
+optimizer_G, optimizer_D = (
+ phase2nuclei_model.module.optimizer_G,
+ phase2nuclei_model.module.optimizer_D,
+)
+# %%
+train_model(
+ opt,
+ phase2nuclei_model,
+ visualizer,
+ dataset_train,
+ dataset_val,
+ optimizer_G,
+ optimizer_D,
+ start_epoch,
+ epoch_iter,
+ writer,
+)
+# %% [markdown]
+"""
+
+
+## A heads up of what to expect from the training...
+
+<<<<<<< HEAD
+ - Visualise results : We can observe how the performance improves over time using the images tab and the sliding window.
+=======
+- Visualise results: We can observe how the performance improves over time using the images tab and the sliding window.
+>>>>>>> ff0946289bce4234aedb07af642c943d6d40dd24
+
+- Discriminator Predicted Probabilities: We plot the discriminator's predicted probability that a (phase, real fluorescence) pair is real and its predicted probability that a (phase, virtual stain) pair is real. The model is typically trained until the discriminator can no longer classify whether the generated images are real or fake better than a random guess (p = 0.5). We plot this for both the training and validation datasets.
+
+<<<<<<< HEAD
+ - Adversarial Loss : We can formulate the adversarial loss as a Least Squared Error Loss in which for real data the discriminator should output a value close to 1 and for fake data a value close to 0. The generator's goal is to make the discriminator output a value as close to 1 for fake data. We plot the least squared error loss.
+
+ - Feature Matching Loss : Both networks are also trained using the generator feature matching loss which encourages the generator to produce images that contain similar statistics to the real images at each scale. We also plot the feature matching L1 loss for the training and validation sets together to observe the performance and how the model is fitting the data.
+=======
+- Adversarial Loss: We can formulate the adversarial loss as a Least Squared Error Loss in which for real data the discriminator should output a value close to 1 and for fake data a value close to 0. The generator's goal is to make the discriminator output a value as close to 1 for fake data. We plot the least squared error loss.
+
+- Feature Matching Loss: Both networks are also trained using the generator feature matching loss which encourages the generator to produce images that contain similar statistics to the real images at each scale. We also plot the feature matching L1 loss for the training and validation sets together to observe the performance and how the model is fitting the data.
+>>>>>>> ff0946289bce4234aedb07af642c943d6d40dd24
+
+This implementation allows for the turning on/off of the least-square loss term by setting the opt.no_lsgan flag to the model options. As well as the turning off of the feature matching loss term by setting the opt.no_ganFeat_loss flag to the model options. Something you might want to explore in the next section!
+
+"""
+# %% [markdown]
+"""
+
+
+## Checkpoint 1
+
+Congratulations! You should now have a better understanding of how to train a Pix2PixHD GAN model for translating from phase to nuclei. You should also have a good understanding of the different loss components of Pix2PixHD GAN and how they are used to train the model.
+
+
+"""
+# %% [markdown]
+"""
+# Part 2: Load and assess the trained Pix2PixHD GAN using tensorboard, and discuss the performance of the model.
+--------------------------------------------------
+Learning goals:
+- Load a pre-trained Pix2PixHD GAN model for phase-to-nuclei translation.
+- Discuss the loss components of Pix2PixHD GAN and how they are used to train the model.
+- Evaluate the fit of the model on the train and validation datasets.
+
+In this part, we will evaluate the performance of the pre-trained model. We will begin by looking qualitatively at the model predictions, then dive into the different loss plots. We will discuss the implications of different hyper-parameter combinations for the performance of the model.
+
+If you are having issues loading the tensorboard session, click "Launch TensorBoard session". You should then be able to add the log_dir path below, and a tensorboard session should then load.
+"""
+# %%
+log_dir = f"{top_dir}/model_tensorboard/{opt.name}/"
+%reload_ext tensorboard
+%tensorboard --logdir $log_dir
+
+# %% [markdown]
+"""
+## Training Results
+Please note down your thoughts about the following questions...
+
+- **What do you notice about the virtual staining predictions? How do they appear compared to the regression-based approach? Can you spot any hallucinations?**
+- **What do you notice about the probabilities of the discriminators? How do the values compare during training compared to validation?**
+- **What do you notice about the feature matching L1 loss?**
+- **What do you notice about the least-squares loss?**
+"""
+
+# %% [markdown]
+"""
+## Checkpoint 2
+
+Congratulations! You should now have a better understanding of the different loss components of Pix2PixHD GAN and how they are used to train the model. You should also have a good understanding of the fit of the model during training on the training and validation datasets.
+"""
+
+
+# %% [markdown]
+"""
+# Part 3: Evaluate performance of the virtual staining on unseen data.
+--------------------------------------------------
+## Evaluate the performance of the model.
+We now look at the same performance metrics as for the previous model. We typically evaluate model performance on held-out test data.
+
+Steps:
+- Define our model parameters for the pre-trained model (these are the same parameters as shown in earlier cells but copied here for clarity).
+- Load the test data.
+
+We will first load the test data using the same format as the training and validation data. We will then use the model to sample a virtual nuclei staining solution from the phase image, and evaluate the performance of the model using the following metrics:
+
+Pixel-level metrics:
+- [Peak-Signal-to-Noise-Ratio (PSNR)](https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio).
+- [Structural Similarity Index Measure (SSIM)](https://en.wikipedia.org/wiki/Structural_similarity).
+- [Pearson Correlation Coefficient (PCC)](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient).
+
+Instance-level metrics via [Cellpose masks](https://cellpose.org/):
+- [Accuracy](https://en.wikipedia.org/wiki/Accuracy_and_precision#In_binary_classification)
+- [Jaccard Index](https://en.wikipedia.org/wiki/Jaccard_index)
+- [Dice Score](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient)
+- [Mean Average Precision](https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision)
+"""
+
+# %%
+opt = TestOptions().parse(save=False)
+
+# Define the parameters for the dataset.
+opt.dataroot = output_image_folder
+opt.data_type = 16 # Data type of the images.
+opt.loadSize = 512 # Size of the loaded phase image.
+opt.input_nc = 1 # Number of input channels.
+opt.output_nc = 1 # Number of output channels.
+opt.target = "nuclei" # "nuclei" or "cyto" depending on your choice of target for virtual stain.
+opt.resize_or_crop = "none" # Scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop|none].
+
+# Define the model parameters for the pre-trained model.
+
+# Define the parameters for the Generator.
+opt.ngf = 64 # Number of filters in the generator.
+opt.n_downsample_global = 4 # Number of downsampling layers in the generator.
+opt.n_blocks_global = 9 # Number of residual blocks in the generator.
+opt.n_blocks_local = 3 # Number of residual blocks in the local enhancer network.
+opt.n_local_enhancers = 1 # Number of local enhancers in the generator.
+
+# Define the parameters for the Discriminators.
+opt.num_D = 3 # Number of discriminators.
+opt.n_layers_D = 3 # Number of layers in the discriminator.
+opt.ndf = 32 # Number of filters in the discriminator.
+
+# Define general training parameters.
+opt.gpu_ids = [0] # GPU ids to use.
+opt.norm = "instance" # Normalization layer in the generator.
+opt.use_dropout = "" # Use dropout in the generator (fixed at 0.2).
+
+# Define loss functions.
+opt.no_vgg_loss = "" # Turn off VGG loss
+opt.no_ganFeat_loss = "" # Turn off feature matching loss
+opt.no_lsgan = "" # Turn off least square loss
+
+# Additional Inference parameters
+opt.name = "dlmbl_vsnuclei"
+opt.how_many = 112 # Number of images to generate.
+opt.checkpoints_dir = f"{top_dir}/model_weights/" # Path to the model checkpoints.
+opt.results_dir = f"{top_dir}/GAN_code/GANs_MI2I/pre_trained/{opt.name}/inference_results/" # Path to store the results.
+opt.which_epoch = "latest" # or specify the epoch number "40"
+opt.phase = "test"
+
+opt.nThreads = 1 # test code only supports nThreads = 1
+opt.batchSize = 1 # test code only supports batchSize = 1
+opt.serial_batches = True # no shuffle
+opt.no_flip = True # no flip
+Path(opt.results_dir).mkdir(parents=True, exist_ok=True)
+
+# Load the test data.
+test_data_loader = CreateDataLoader(opt)
+test_dataset = test_data_loader.load_data()
+visualizer = Visualizer(opt)
+print(f"Total Test Images = {len(test_data_loader)}")
+# Load pre-trained model
+model = create_model(opt)
+
+# %%
+# Generate & save predictions in the results directory.
+inference_model(test_dataset, opt, model)
+
+# %%
+# Gather results for evaluation
+virtual_stain_paths = sorted([i for i in Path(opt.results_dir).glob("**/*.tiff")])
+target_stain_paths = sorted([i for i in Path(f"{output_image_folder}/{translation_task}/test/").glob("**/*.tiff")])
+phase_paths = sorted([i for i in Path(f"{output_image_folder}/input/test/").glob("**/*.tiff")])
+assert (
+    len(virtual_stain_paths) == len(target_stain_paths) == len(phase_paths)
+), f"Number of images does not match: {len(virtual_stain_paths)}, {len(target_stain_paths)}, {len(phase_paths)}"
+
+# Create arrays to store the images.
+virtual_stains = np.zeros((len(virtual_stain_paths), 512, 512))
+target_stains = virtual_stains.copy()
+phase_images = virtual_stains.copy()
+# Load the images and store them in the arrays.
+for index, (v_path, t_path, p_path) in tqdm(
+ enumerate(zip(virtual_stain_paths, target_stain_paths, phase_paths))
+):
+ virtual_stain = imread(v_path)
+ phase_image = imread(p_path)
+ target_stain = imread(t_path)
+ # Append the images to the arrays.
+ phase_images[index] = phase_image
+ target_stains[index] = target_stain
+ virtual_stains[index] = virtual_stain
+
+# %% [markdown] tags=[]
+"""
+
+
+### Task 3.1 Visualise the results of the model on the test set.
+
+Create a matplotlib plot that visualises random samples of the phase images, target stains, and virtual stains.
+If you can incorporate the crop function below to zoom in on the images, that would be great!
+
+"""
+# %%
+# Define a function to crop the images so we can zoom in.
+def crop(img, crop_size, loc='center'):
+ """
+ Crop the input image.
+
+ Parameters:
+ img (ndarray): The image to be cropped.
+ crop_size (int): The size of the crop.
+ loc (str): The type of crop to perform. Can be 'center' or 'random'.
+
+ Returns:
+ ndarray: The cropped image array.
+ """
+    # Dimensions of the input array (numpy arrays are indexed as (row, column) = (y, x)).
+    height, width = img.shape
+
+    if loc == 'random':
+        # Generate random top-left coordinates for the crop.
+        max_y = height - crop_size
+        max_x = width - crop_size
+        start_y = np.random.randint(0, max_y + 1)
+        start_x = np.random.randint(0, max_x + 1)
+    elif loc == 'center':
+        start_y = (height - crop_size) // 2
+        start_x = (width - crop_size) // 2
+    else:
+        raise ValueError(f'Unknown crop type {loc}')
+
+    end_y = start_y + crop_size
+    end_x = start_x + crop_size
+
+    # Crop the array using slicing.
+    crop_array = img[start_y:end_y, start_x:end_x]
+    return crop_array
+
+# %% tags=["task"]
+##########################
+######## TODO ########
+##########################
+
+def visualise_results():
+ # Your code here
+ pass
+
+
+# %% tags=["solution"]
+
+##########################
+######## Solution ########
+##########################
+
+def visualise_results(phase_images, target_stains, virtual_stains, crop_size=None, loc='center'):
+ """
+ Visualizes the results of image processing by displaying the phase images, target stains, and virtual stains.
+ Parameters:
+ - phase_images (np.array): Array of phase images.
+ - target_stains (np.array): Array of target stains.
+ - virtual_stains (np.array): Array of virtual stains.
+    - crop_size (int, optional): Size of the crop. Defaults to None (no cropping).
+    - loc (str, optional): Type of crop. Defaults to 'center' but can be 'random'.
+ Returns:
+ None
+ """
+ fig, axes = plt.subplots(5, 3, figsize=(15, 20))
+ sample_indices = np.random.choice(len(phase_images), 5)
+    for index, sample in enumerate(sample_indices):
+        if crop_size:
+            phase_image = crop(phase_images[sample], crop_size, loc)
+            target_stain = crop(target_stains[sample], crop_size, loc)
+            virtual_stain = crop(virtual_stains[sample], crop_size, loc)
+        else:
+            phase_image = phase_images[sample]
+            target_stain = target_stains[sample]
+            virtual_stain = virtual_stains[sample]
+
+ axes[index, 0].imshow(phase_image, cmap="gray")
+ axes[index, 0].set_title("Phase")
+ axes[index, 1].imshow(
+ target_stain,
+ cmap="gray",
+ vmin=np.percentile(target_stain, 1),
+ vmax=np.percentile(target_stain, 99),
+ )
+ axes[index, 1].set_title("Target Fluorescence ")
+ axes[index, 2].imshow(
+ virtual_stain,
+ cmap="gray",
+ vmin=np.percentile(virtual_stain, 1),
+ vmax=np.percentile(virtual_stain, 99),
+ )
+ axes[index, 2].set_title("Virtual Stain")
+ for ax in axes.flatten():
+ ax.axis("off")
+ plt.tight_layout()
+ plt.show()
+
+visualise_results(phase_images, target_stains, virtual_stains, crop_size=None)
+
+# %% [markdown] tags=[]
+"""
+
+
+### Task 3.2 Compute pixel-level metrics
+
+Compute the pixel-level metrics for the virtual stains and target stains.
+
+The following code will compute:
+- the pixel-based metrics (Pearson correlation, SSIM, PSNR) between the virtual stains and target stains.
+
+
+"""
+# %%
+
+# Define the function to perform minmax normalization which is required for the pixel-level metrics.
+def min_max_scale(input):
+ return (input - np.min(input)) / (np.max(input) - np.min(input))
+
+# Create a dataframe to store the pixel-level metrics.
+test_pixel_metrics = pd.DataFrame(
+ columns=["model", "fov","pearson_nuc", "ssim_nuc", "psnr_nuc"]
+)
+
+# Compute the pixel-level metrics.
+for i, (target_stain, predicted_stain) in tqdm(enumerate(zip(target_stains, virtual_stains))):
+ fov = str(virtual_stain_paths[i]).split("/")[-1].split(".")[0]
+ minmax_norm_target = min_max_scale(target_stain)
+ minmax_norm_predicted = min_max_scale(predicted_stain)
+
+ # Compute SSIM
+ ssim_nuc = metrics.structural_similarity(
+ minmax_norm_target, minmax_norm_predicted, data_range=1
+ )
+ # Compute Pearson correlation
+ pearson_nuc = np.corrcoef(
+ minmax_norm_target.flatten(), minmax_norm_predicted.flatten()
+ )[0, 1]
+ # Compute PSNR
+ psnr_nuc = metrics.peak_signal_noise_ratio(
+ minmax_norm_target, minmax_norm_predicted, data_range=1
+ )
+
+ test_pixel_metrics.loc[len(test_pixel_metrics)] = {
+ "model": "pix2pixHD",
+ "fov":fov,
+ "pearson_nuc": pearson_nuc,
+ "ssim_nuc": ssim_nuc,
+ "psnr_nuc": psnr_nuc,
+ }
+
+test_pixel_metrics.boxplot(
+ column=["pearson_nuc", "ssim_nuc"],
+ rot=30,
+)
+# %%
+test_pixel_metrics.boxplot(
+ column=["psnr_nuc"],
+ rot=30,
+)
+# %%
+test_pixel_metrics.head()
+# %% [markdown]
+"""
+## Inference Pixel-level Results
+Please note down your thoughts about the following questions...
+
+- What do these metrics tell us about the performance of the model?
+- How do the pixel-level metrics compare to the regression-based approach?
+- Could these metrics be skewed by the presence of hallucinations or background pixels in the virtual stains?
+
+"""
+# %% [markdown]
+
+"""
+
+
+### Task 3.3 Compute instance-level metrics
+
+- Compute the instance-level metrics for the virtual stains and target stains.
+- Instance-level metrics include accuracy (fraction of correct predictions at a 0.5 threshold), the Jaccard index (intersection over union, IoU), the Dice score (twice the intersection divided by the sum of the two masks), mean average precision (mAP), mAP at 50% IoU, mAP at 75% IoU, and mean average recall (mAR) over up to 100 detections per image.
+
+
+
+"""
+# %%
+
+# Use the same function as in the previous part to extract nuclei masks from the pre-trained Cellpose model.
+def cellpose_segmentation(prediction:ArrayLike,target:ArrayLike)->Tuple[torch.ShortTensor]:
+ # NOTE these are hardcoded for this notebook and A549 dataset
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ cp_nuc_kwargs = {
+ "diameter": 65,
+ "cellprob_threshold": 0.0,
+ }
+ cellpose_model = models.CellposeModel(
+ gpu=True, model_type='nuclei', device=torch.device(device)
+ )
+ pred_label, _, _ = cellpose_model.eval(prediction, **cp_nuc_kwargs)
+ target_label, _, _ = cellpose_model.eval(target, **cp_nuc_kwargs)
+
+ pred_label = pred_label.astype(np.int32)
+ target_label = target_label.astype(np.int32)
+ pred_label = torch.ShortTensor(pred_label)
+ target_label = torch.ShortTensor(target_label)
+
+ return (pred_label,target_label)
+
+# Define dataframe to store the segmentation metrics.
+test_segmentation_metrics = pd.DataFrame(
+    columns=["model", "fov", "masks_per_fov", "accuracy", "dice", "jaccard", "mAP", "mAP_50", "mAP_75", "mAR_100"]
+)
+# Define a list to store the segmentation results. Each entry is a dictionary containing the model name, fov, predicted label, predicted stain, target label, and target stain.
+segmentation_results = []
+
+for i, (target_stain, predicted_stain) in tqdm(enumerate(zip(target_stains, virtual_stains)),desc='Computing Metrics',total=len(target_stains)):
+ fov = str(virtual_stain_paths[i]).split("/")[-1].split(".")[0]
+ minmax_norm_target = min_max_scale(target_stain)
+ minmax_norm_predicted = min_max_scale(predicted_stain)
+ # Compute the segmentation masks.
+ pred_label, target_label = cellpose_segmentation(minmax_norm_predicted, minmax_norm_target)
+ # Binary labels
+ pred_label_binary = pred_label > 0
+ target_label_binary = target_label > 0
+
+ # Use Coco metrics to get mean average precision
+ coco_metrics = mean_average_precision(pred_label, target_label)
+    # Find unique number of labels
+    num_masks_fov = len(np.unique(pred_label))
+ # Compute the segmentation metrics.
+ test_segmentation_metrics.loc[len(test_segmentation_metrics)] = {
+ "model": "pix2pixHD",
+ "fov":fov,
+ "masks_per_fov": num_masks_fov,
+ "accuracy": accuracy(pred_label_binary, target_label_binary, task="binary").item(),
+ "dice": dice(pred_label_binary, target_label_binary).item(),
+ "jaccard": jaccard_index(pred_label_binary, target_label_binary, task="binary").item(),
+ "mAP":coco_metrics["map"].item(),
+ "mAP_50":coco_metrics["map_50"].item(),
+ "mAP_75":coco_metrics["map_75"].item(),
+ "mAR_100":coco_metrics["mar_100"].item()
+ }
+ # Store the segmentation results.
+ segmentation_result = {
+ "model": "pix2pixHD",
+ "fov":fov,
+ "phase_image": phase_images[i],
+ "pred_label": pred_label,
+ "pred_stain": predicted_stain,
+ "target_label": target_label,
+ "target_stain": target_stain,
+ }
+    segmentation_results.append(segmentation_result)
+
+test_segmentation_metrics.head()
+# %%
+# Define function to visualize the segmentation results.
+def visualise_results_and_masks(segmentation_results: list, segmentation_metrics: pd.DataFrame, rows: int = 5, crop_size: int = None, crop_type: str = 'center'):
+
+    # Sample a random subset of the segmentation results.
+    sample_indices = np.random.choice(len(segmentation_results), rows, replace=False)
+    segmentation_metrics = segmentation_metrics.iloc[sample_indices, :]
+    segmentation_results = [segmentation_results[i] for i in sample_indices]
+    # Define the figure and axes.
+    fig, axes = plt.subplots(rows, 5, figsize=(15, rows * 3))
+
+ # Visualize the segmentation results.
+    for i in range(len(segmentation_results)):
+ segmentation_metric = segmentation_metrics.iloc[i]
+ result = segmentation_results[i]
+ phase_image = result["phase_image"]
+ target_stain = result["target_stain"]
+ target_label = result["target_label"]
+ pred_stain = result["pred_stain"]
+ pred_label = result["pred_label"]
+ # Crop the images if required. Zoom into instances
+ if crop_size is not None:
+ phase_image = crop(phase_image, crop_size, crop_type)
+ target_stain = crop(target_stain, crop_size, crop_type)
+ target_label = crop(target_label, crop_size, crop_type)
+ pred_stain = crop(pred_stain, crop_size, crop_type)
+ pred_label = crop(pred_label, crop_size, crop_type)
+
+ axes[i, 0].imshow(phase_image, cmap="gray")
+ axes[i, 0].set_title("Phase")
+ axes[i, 1].imshow(
+ target_stain,
+ cmap="gray",
+ vmin=np.percentile(target_stain, 1),
+ vmax=np.percentile(target_stain, 99),
+ )
+ axes[i, 1].set_title("Target Fluorescence")
+ axes[i, 2].imshow(pred_stain, cmap="gray")
+ axes[i, 2].set_title("Virtual Stain")
+ axes[i, 3].imshow(target_label, cmap="inferno")
+ axes[i, 3].set_title("Target Fluorescence Mask")
+ axes[i, 4].imshow(pred_label, cmap="inferno")
+ # Add Metric values to the title
+ axes[i, 4].set_title(f"Virtual Stain Mask\nAcc:{segmentation_metric['accuracy']:.2f} Dice:{segmentation_metric['dice']:.2f}\nJaccard:{segmentation_metric['jaccard']:.2f} MAP:{segmentation_metric['mAP']:.2f}")
+ # Turn off the axes.
+ for ax in axes.flatten():
+ ax.axis("off")
+
+ plt.tight_layout()
+ plt.show()
+
+visualise_results_and_masks(segmentation_results, test_segmentation_metrics, crop_size=256, crop_type='center')
+
+# %% [markdown]
+"""
+## Inference Instance-level Results
+Please note down your thoughts about the following questions...
+
+- What do these metrics tell us about the performance of the model?
+- How does the performance compare to when looking at pixel-level metrics?
+
+"""
+# %% [markdown]
+"""
+
+
+## Checkpoint 3
+
+Congratulations! You have generated predictions from a pre-trained model and evaluated its performance on unseen data. You have computed pixel-level and instance-level metrics to evaluate the performance of the model. You may also have begun training your own Pix2PixHD GAN models with alternative hyperparameters.
+
+
+
+## Checkpoint 4
+
+Congratulations! You should now have a better understanding of the difference in performance for image translation when approaching the problem with a regression-based vs. a generative modelling approach!
+
+
+"""
+# %% [markdown]
+"""
+# Part 5: BONUS: Sample different virtual staining solutions from the GAN using MC-Dropout and explore the uncertainty in the virtual stain predictions.
+--------------------------------------------------
+Steps:
+- Load the pre-trained model.
+- Generate multiple predictions for the same input image.
+- Compute the pixel-wise variance across the predictions (a minimal sketch of this step is included after the sampling cells below).
+- Visualise the pixel-wise variance to explore the uncertainty in the virtual stain predictions.
+
+"""
+# %%
+# Use the same model and dataloaders as before.
+# Load the test data.
+test_data_loader = CreateDataLoader(opt)
+test_dataset = test_data_loader.load_data()
+visualizer = Visualizer(opt)
+
+# Load pre-trained model
+opt.variational_inf_runs = 100 # Number of samples per phase input
+opt.variation_inf_path = f"./GAN_code/GANs_MI2I/pre_trained/{opt.name}/samples/" # Path to store the samples.
+opt.results_dir = f"{top_dir}/GAN_code/GANs_MI2I/pre_trained/dlmbl_vsnuclei/sampling_results"
+opt.dropout_variation_inf = True # Use dropout during inference.
+model = create_model(opt)
+# Generate & save predictions in the variation_inf_path directory.
+sampling(test_dataset, opt, model)
+
+# %%
+# Visualise Samples
+samples = sorted([i for i in Path(f"{opt.results_dir}").glob("**/*.tif*")])
+assert len(samples) == 5
+# Create arrays to store the images.
+sample_images = np.zeros((5,100, 512, 512)) # (samples, images, height, width)
+# Load the images and store them in the arrays.
+for index, sample_path in tqdm(enumerate(samples)):
+ sample_image = imread(sample_path)
+ # Append the images to the arrays.
+ sample_images[index] = sample_image
+
+# %%
+# Create a matplotlib plot with animation through images.
+def animate_images(images):
+ # Expecting images to have shape (frames, height, width)
+ fig, ax = plt.subplots()
+ ax.axis('off')
+
+ # Make sure images are in (frames, height, width) order
+ images = images.transpose(0, 2, 1) if images.shape[1] == images.shape[2] else images
+
+ imgs = []
+ for i in range(min(100, len(images))): # Ensure you don't exceed the number of frames
+ im = ax.imshow(images[i], animated=True)
+ imgs.append([im])
+
+ ani = animation.ArtistAnimation(fig, imgs, interval=100, blit=False, repeat_delay=1000)
+
+ # Display the animation
+ # plt.close(fig)
+ display(HTML(ani.to_jshtml()))
+
+# Example call with sample_images[0]
+animate_images(sample_images[0])
+
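+# %% [markdown]
+"""
+The steps above mention computing the pixel-wise variance across the MC-Dropout samples. The next cell is a minimal sketch of that step, assuming (as suggested by the loading code above) that each entry of `sample_images` stacks the 100 sampled virtual stains for one phase input. It shows the per-pixel mean and standard deviation for the first input; regions with a high standard deviation are where the GAN is most uncertain about the virtual stain.
+"""
+# %%
+# Pixel-wise mean and standard deviation across the MC-Dropout samples of the first input.
+mean_virtual_stain = sample_images[0].mean(axis=0)
+std_virtual_stain = sample_images[0].std(axis=0)
+
+fig, axes = plt.subplots(1, 2, figsize=(10, 5))
+axes[0].imshow(mean_virtual_stain, cmap="gray")
+axes[0].set_title("Mean virtual stain (100 samples)")
+axes[1].imshow(std_virtual_stain, cmap="inferno")
+axes[1].set_title("Pixel-wise standard deviation")
+for ax in axes:
+    ax.axis("off")
+plt.tight_layout()
+plt.show()
+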
+# %% [markdown]
+"""
+
+
+## Checkpoint 5
+
+Congratulations! This is the end of the conditional generative modelling approach to the image translation notebook. You have trained and examined the loss components of Pix2PixHD GAN. You have compared the results of a regression-based approach vs. a generative modelling approach and explored the variability in virtual staining solutions. I hope you have enjoyed the learning experience!
+
+"""
diff --git a/setup.sh b/setup.sh
deleted file mode 100644
index 9472b16..0000000
--- a/setup.sh
+++ /dev/null
@@ -1,32 +0,0 @@
-#!/usr/bin/env -S bash -i
-
-START_DIR=$(pwd)
-
-# Create mamba environment
-mamba create -y --name 04_image_translation python=3.10
-
-# Install ipykernel in the environment.
-mamba install -y ipykernel nbformat nbconvert black jupytext ipywidgets --name 04_image_translation
-# Specifying the environment explicitly.
-# mamba activate sometimes doesn't work from within shell scripts.
-
-# install viscy and its dependencies`s in the environment using pip.
-mkdir -p ~/code/
-cd ~/code/
-git clone https://github.com/mehta-lab/viscy.git
-cd viscy
-git checkout 7c5e4c1d68e70163cf514d22c475da8ea7dc3a88 # Exercise is tested with this commit of viscy
-# Find path to the environment - mamba activate doesn't work from within shell scripts.
-ENV_PATH=$(conda info --envs | grep 04_image_translation | awk '{print $NF}')
-$ENV_PATH/bin/pip install ."[metrics]"
-
-# Create data directory
-mkdir -p ~/data/04_image_translation
-cd ~/data/04_image_translation
-wget https://dl-at-mbl-2023-data.s3.us-east-2.amazonaws.com/DLMBL2023_image_translation_data_pyramid.tar.gz
-wget https://dl-at-mbl-2023-data.s3.us-east-2.amazonaws.com/DLMBL2023_image_translation_test.tar.gz
-tar -xzf DLMBL2023_image_translation_data_pyramid.tar.gz
-tar -xzf DLMBL2023_image_translation_test.tar.gz
-
-# Change back to the starting directory
-cd $START_DIR
diff --git a/solution.ipynb b/solution.ipynb
deleted file mode 100644
index 5775c64..0000000
--- a/solution.ipynb
+++ /dev/null
@@ -1,1314 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "15d23751",
- "metadata": {
- "cell_marker": "\"\"\""
- },
- "source": [
- "# Image translation\n",
- "---\n",
- "\n",
- "Written by Ziwen Liu and Shalin Mehta, CZ Biohub San Francisco.\n",
- "\n",
- "In this exercise, we will solve an image translation task to predict fluorescence images of nuclei and membrane markers from quantitative phase images of cells. In other words, we will _virtually stain_ the nuclei and membrane visible in the phase image. \n",
- "\n",
- "Here, the source domain is label-free microscopy (material density) and the target domain is fluorescence microscopy (fluorophore density). The goal is to learn a mapping from the source domain to the target domain. We will use a deep convolutional neural network (CNN), specifically, a U-Net model with residual connections to learn the mapping. The preprocessing, training, prediction, evaluation, and deployment steps are unified in a computer vision pipeline for single-cell analysis that we call [VisCy](https://github.com/mehta-lab/VisCy).\n",
- "\n",
- "VisCy evolved from our previous work on virtual staining of cellular components from their density and anisotropy.\n",
- "![](https://iiif.elifesciences.org/lax/55502%2Felife-55502-fig1-v2.tif/full/1500,/0/default.jpg)\n",
- "\n",
- "[Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning\n",
- ". eLife](https://elifesciences.org/articles/55502).\n",
- "\n",
- "VisCy exploits recent advances in the data and metadata formats ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks, [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). "
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b5957320",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 0
- },
- "source": [
- "Today, we will train a 2D image translation model using a 2D U-Net with residual connections. We will use a dataset of 301 fields of view (FOVs) of Human Embryonic Kidney (HEK) cells, each FOV has 3 channels (phase, membrane, and nuclei). The cells were labeled with CRISPR editing. Intrestingly, not all cells during this experiment were labeled due to the stochastic nature of CRISPR editing. In such situations, virtual staining rescues missing labels.\n",
- "![HEK](https://github.com/mehta-lab/VisCy/blob/dlmbl2023/docs/figures/phase_to_nuclei_membrane.svg?raw=true)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "6c62db93",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 0
- },
- "source": [
- "
\n",
- "The exercise is organized in 3 parts.\n",
- "\n",
- "* **Part 1** - Explore the data using tensorboard. Launch the training before lunch.\n",
- "* Lunch break - The model will continue training during lunch.\n",
- "* **Part 2** - Evaluate the training with tensorboard. Train another model.\n",
- "* **Part 3** - Tune the models to improve performance.\n",
- "
"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "30739bb5",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 0
- },
- "source": [
- "📖 As you work through parts 2 and 3, please share the layouts of your models (output of torchview) and their performance with everyone via [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z) 📖.\n",
- "\n",
- "\n",
- "Our guesstimate is that each of the three parts will take ~1.5 hours. A reasonable 2D UNet can be trained in ~20 min on a typical AWS node. \n",
- "We will discuss your observations on google doc after checkpoints 2 and 3.\n",
- "\n",
- "The focus of the exercise is on understanding information content of the data, how to train and evaluate 2D image translation model, and explore some hyperparameters of the model. If you complete this exercise and have time to spare, try the bonus exercise on 3D image translation."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "658e3b31",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 0
- },
- "source": [
- "
\n",
- "Set your python kernel to 04_image_translation\n",
- "
"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "be433880",
- "metadata": {
- "tags": [],
- "title": "Imports and paths"
- },
- "outputs": [],
- "source": [
- "%reload_ext tensorboard\n",
- "%tensorboard --logdir {log_dir} "
- ]
- },
- {
- "cell_type": "markdown",
- "id": "57cf6b27",
- "metadata": {
- "cell_marker": "\"\"\""
- },
- "source": [
- "## Load Dataset.\n",
- "\n",
- "There should be 301 FOVs in the dataset (12 GB compressed).\n",
- "\n",
- "Each FOV consists of 3 channels of 2048x2048 images,\n",
- "saved in the \n",
- "High-Content Screening (HCS) layout\n",
- "specified by the Open Microscopy Environment Next Generation File Format\n",
- "(OME-NGFF).\n",
- "\n",
- "The layout on the disk is: row/col/field/pyramid_level/timepoint/channel/z/y/x.\n",
- "Notice that labelling of nuclei channel is not complete - some cells are not expressing the fluorescent protein."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "08497569",
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset = open_ome_zarr(data_path)\n",
- "\n",
- "print(f\"Number of positions: {len(list(dataset.positions()))}\")\n",
- "\n",
- "# Use the field and pyramid_level below to visualize data.\n",
- "row = 0\n",
- "col = 0\n",
- "field = 23 # TODO: Change this to explore data.\n",
- "\n",
- "# This dataset contains images at 3 resolutions.\n",
- "# '0' is the highest resolution\n",
- "# '1' is down-scaled 2x2,\n",
- "# '2' is down-scaled 4x4.\n",
- "# Such datasets are called image pyramids.\n",
- "pyaramid_level = 0\n",
- "\n",
- "# `channel_names` is the metadata that is stored with data according to the OME-NGFF spec.\n",
- "n_channels = len(dataset.channel_names)\n",
- "\n",
- "image = dataset[f\"{row}/{col}/{field}/{pyaramid_level}\"].numpy()\n",
- "print(f\"data shape: {image.shape}, FOV: {field}, pyramid level: {pyaramid_level}\")\n",
- "\n",
- "figure, axes = plt.subplots(1, n_channels, figsize=(9, 3))\n",
- "\n",
- "for i in range(n_channels):\n",
- " for i in range(n_channels):\n",
- " channel_image = image[0, i, 0]\n",
- " # Adjust contrast to 0.5th and 99.5th percentile of pixel values.\n",
- " p_low, p_high = np.percentile(channel_image, (0.5, 99.5))\n",
- " channel_image = np.clip(channel_image, p_low, p_high)\n",
- " axes[i].imshow(channel_image, cmap=\"gray\")\n",
- " axes[i].axis(\"off\")\n",
- " axes[i].set_title(dataset.channel_names[i])\n",
- "plt.tight_layout()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "8e46fc19",
- "metadata": {},
- "source": [
- "
\n",
- "\n",
- "### Task 1.1\n",
- " \n",
- "Look at a couple different fields of view by changing the value in the cell above. See if you notice any missing or inconsistent staining.\n",
- "
"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "7f6f3609",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 1
- },
- "source": [
- "## Explore the effects of augmentation on batch.\n",
- "\n",
- "VisCy builds on top of PyTorch Lightning. PyTorch Lightning is a thin wrapper around PyTorch that allows rapid experimentation. It provides a [DataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) to handle loading and processing of data during training. VisCy provides a child class, `HCSDataModule` to make it intuitve to access data stored in the HCS layout.\n",
- " \n",
- "The dataloader in `HCSDataModule` returns a batch of samples. A `batch` is a list of dictionaries. The length of the list is equal to the batch size. Each dictionary consists of following key-value pairs.\n",
- "- `source`: the input image, a tensor of size 1*1*Y*X\n",
- "- `target`: the target image, a tensor of size 2*1*Y*X\n",
- "- `index` : the tuple of (location of field in HCS layout, time, and z-slice) of the sample."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "1684d72e",
- "metadata": {},
- "source": [
- "
\n",
- "\n",
- "### Task 1.2\n",
- "\n",
- "Setup the data loader and log several batches to tensorboard.\n",
- "\n",
- "Based on the tensorboard images, what are the two channels in the target image?\n",
- "\n",
- "Note: If tensorboard is not showing images, try refreshing and using the \"Images\" tab.\n",
- "
"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "67211280",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define a function to write a batch to tensorboard log.\n",
- "\n",
- "def log_batch_tensorboard(batch, batchno, writer, card_name):\n",
- " \"\"\"\n",
- " Logs a batch of images to TensorBoard.\n",
- "\n",
- " Args:\n",
- " batch (dict): A dictionary containing the batch of images to be logged.\n",
- " writer (SummaryWriter): A TensorBoard SummaryWriter object.\n",
- " card_name (str): The name of the card to be displayed in TensorBoard.\n",
- "\n",
- " Returns:\n",
- " None\n",
- " \"\"\"\n",
- " batch_phase = batch[\"source\"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor.\n",
- " batch_membrane = batch[\"target\"][:, 1, 0, :, :].unsqueeze(\n",
- " 1\n",
- " ) # batch_size x 1 x Y x X tensor.\n",
- " batch_nuclei = batch[\"target\"][:, 0, 0, :, :].unsqueeze(\n",
- " 1\n",
- " ) # batch_size x 1 x Y x X tensor.\n",
- "\n",
- " p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))\n",
- " batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)\n",
- "\n",
- " p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))\n",
- " batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)\n",
- "\n",
- " p1, p99 = np.percentile(batch_phase, (0.1, 99.9))\n",
- " batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)\n",
- "\n",
- " [N, C, H, W] = batch_phase.shape\n",
- " interleaved_images = torch.zeros((3 * N, C, H, W), dtype=batch_phase.dtype)\n",
- " interleaved_images[0::3, :] = batch_phase\n",
- " interleaved_images[1::3, :] = batch_nuclei\n",
- " interleaved_images[2::3, :] = batch_membrane\n",
- "\n",
- " grid = torchvision.utils.make_grid(interleaved_images, nrow=3)\n",
- "\n",
- " # add the grid to tensorboard\n",
- " writer.add_image(card_name, grid, batchno)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "73577e5a",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define a function to visualize a batch on jupyter, in case tensorboard is finicky \n",
- "\n",
- "def log_batch_jupyter(batch):\n",
- " \"\"\"\n",
- " Logs a batch of images on jupyter using ipywidget.\n",
- "\n",
- " Args:\n",
- " batch (dict): A dictionary containing the batch of images to be logged.\n",
- "\n",
- " Returns:\n",
- " None\n",
- " \"\"\"\n",
- " batch_phase = batch[\"source\"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor.\n",
- " batch_size = batch_phase.shape[0]\n",
- " batch_membrane = batch[\"target\"][:, 1, 0, :, :].unsqueeze(\n",
- " 1\n",
- " ) # batch_size x 1 x Y x X tensor.\n",
- " batch_nuclei = batch[\"target\"][:, 0, 0, :, :].unsqueeze(\n",
- " 1\n",
- " ) # batch_size x 1 x Y x X tensor.\n",
- "\n",
- " p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))\n",
- " batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)\n",
- "\n",
- " p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))\n",
- " batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)\n",
- "\n",
- " p1, p99 = np.percentile(batch_phase, (0.1, 99.9))\n",
- " batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)\n",
- "\n",
- " plt.figure()\n",
- " fig, axes = plt.subplots(batch_size, n_channels, figsize=(10, 10))\n",
- " [N, C, H, W] = batch_phase.shape\n",
- " for sample_id in range(batch_size):\n",
- " axes[sample_id, 0].imshow(batch_phase[sample_id,0])\n",
- " axes[sample_id, 1].imshow(batch_nuclei[sample_id,0])\n",
- " axes[sample_id, 2].imshow(batch_membrane[sample_id,0])\n",
- "\n",
- " for i in range(n_channels):\n",
- " axes[sample_id, i].axis(\"off\")\n",
- " axes[sample_id, i].set_title(dataset.channel_names[i])\n",
- " plt.tight_layout()\n",
- " plt.show()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "398f4546",
- "metadata": {
- "lines_to_next_cell": 2
- },
- "outputs": [],
- "source": [
- "\n",
- "# Initialize the data module.\n",
- "\n",
- "BATCH_SIZE = 4\n",
- "# 42 is a perfectly reasonable batch size. After all, it is the answer to the ultimate question of life, the universe and everything.\n",
- "# More seriously, batch size does not have to be a power of 2.\n",
- "# See: https://sebastianraschka.com/blog/2022/batch-size-2.html\n",
- "\n",
- "data_module = HCSDataModule(\n",
- " data_path,\n",
- " source_channel=\"Phase\",\n",
- " target_channel=[\"Nuclei\", \"Membrane\"],\n",
- " z_window_size=1,\n",
- " split_ratio=0.8,\n",
- " batch_size=BATCH_SIZE,\n",
- " num_workers=8,\n",
- " architecture=\"2D\",\n",
- " yx_patch_size=(512, 512), # larger patch size makes it easy to see augmentations.\n",
- " augment=False, # Turn off augmentation for now.\n",
- ")\n",
- "data_module.setup(\"fit\")\n",
- "\n",
- "print(\n",
- " f\"FOVs in training set: {len(data_module.train_dataset)}, FOVs in validation set:{len(data_module.val_dataset)}\"\n",
- ")\n",
- "train_dataloader = data_module.train_dataloader()\n",
- "\n",
- "# Instantiate the tensorboard SummaryWriter, logs the first batch and then iterates through all the batches and logs them to tensorboard.\n",
- "\n",
- "writer = SummaryWriter(log_dir=f\"{log_dir}/view_batch\")\n",
- "# Draw a batch and write to tensorboard.\n",
- "batch = next(iter(train_dataloader))\n",
- "log_batch_tensorboard(batch, 0, writer, \"augmentation/none\")\n",
- "writer.close()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "a4c45450",
- "metadata": {},
- "source": [
- "Visualize directly on Jupyter ☄️, if your tensorboard is causing issues."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "73466d3f",
- "metadata": {},
- "outputs": [],
- "source": [
- "%matplotlib inline\n",
- "log_batch_jupyter(batch)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "19def8d6",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 0
- },
- "source": [
- "## View augmentations using tensorboard."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "97bdcbd8",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Here we turn on data augmentation and rerun setup\n",
- "data_module.augment = True\n",
- "data_module.setup(\"fit\")\n",
- "\n",
- "# get the new data loader with augmentation turned on\n",
- "augmented_train_dataloader = data_module.train_dataloader()\n",
- "\n",
- "# Draw batches and write to tensorboard\n",
- "writer = SummaryWriter(log_dir=f\"{log_dir}/view_batch\")\n",
- "augmented_batch = next(iter(augmented_train_dataloader))\n",
- "log_batch_tensorboard(augmented_batch, 0, writer, \"augmentation/some\")\n",
- "writer.close()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "247bf9c7",
- "metadata": {},
- "source": [
- "Visualize directly on Jupyter ☄️"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "de281de7",
- "metadata": {},
- "outputs": [],
- "source": [
- "log_batch_jupyter(augmented_batch)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "47cc91e4",
- "metadata": {},
- "source": [
- "
\n",
- "\n",
- "### Task 1.3\n",
- "Can you tell what augmentation were applied from looking at the augmented images in Tensorboard?\n",
- "\n",
- "Check your answer using the source code [here](https://github.com/mehta-lab/VisCy/blob/b89f778b34735553cf155904eef134c756708ff2/viscy/light/data.py#L529).\n",
- "
"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5d05336e",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 0
- },
- "source": [
- "## Train a 2D U-Net model to predict nuclei and membrane from phase.\n",
- "\n",
- "### Construct a 2D U-Net\n",
- "See ``viscy.unet.networks.Unet2D.Unet2d`` ([source code](https://github.com/mehta-lab/VisCy/blob/7c5e4c1d68e70163cf514d22c475da8ea7dc3a88/viscy/unet/networks/Unet2D.py#L7)) for configuration details."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "964d5aae",
- "metadata": {
- "lines_to_next_cell": 2
- },
- "outputs": [],
- "source": [
- "# Create a 2D UNet.\n",
- "GPU_ID = 0\n",
- "BATCH_SIZE = 10\n",
- "YX_PATCH_SIZE = (512, 512)\n",
- "\n",
- "\n",
- "# Dictionary that specifies key parameters of the model.\n",
- "phase2fluor_config = {\n",
- " \"architecture\": \"2D\",\n",
- " \"num_filters\": [24, 48, 96, 192, 384],\n",
- " \"in_channels\": 1,\n",
- " \"out_channels\": 2,\n",
- " \"residual\": True,\n",
- " \"dropout\": 0.1, # dropout randomly turns off weights to avoid overfitting of the model to data.\n",
- " \"task\": \"reg\", # reg = regression task.\n",
- "}\n",
- "\n",
- "phase2fluor_model = VSUNet(\n",
- " model_config=phase2fluor_config.copy(),\n",
- " batch_size=BATCH_SIZE,\n",
- " loss_function=torch.nn.functional.l1_loss,\n",
- " schedule=\"WarmupCosine\",\n",
- " log_num_samples=5, # Number of samples from each batch to log to tensorboard.\n",
- " example_input_yx_shape=YX_PATCH_SIZE,\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "eabb4902",
- "metadata": {
- "cell_marker": "\"\"\"",
- "lines_to_next_cell": 0
- },
- "source": [
- "### Instantiate data module and trainer, test that we are setup to launch training."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9728a1c0",
- "metadata": {
- "lines_to_next_cell": 2
- },
- "outputs": [],
- "source": [
- "# Setup the data module.\n",
- "phase2fluor_data = HCSDataModule(\n",
- " data_path,\n",
- " source_channel=\"Phase\",\n",
- " target_channel=[\"Nuclei\", \"Membrane\"],\n",
- " z_window_size=1,\n",
- " split_ratio=0.8,\n",
- " batch_size=BATCH_SIZE,\n",
- " num_workers=8,\n",
- " architecture=\"2D\",\n",
- " yx_patch_size=YX_PATCH_SIZE,\n",
- " augment=True,\n",
- ")\n",
- "phase2fluor_data.setup(\"fit\")\n",
- "# fast_dev_run runs a single batch of data through the model to check for errors.\n",
- "trainer = VSTrainer(accelerator=\"gpu\", devices=[GPU_ID], fast_dev_run=True)\n",
- "\n",
- "# trainer class takes the model and the data module as inputs.\n",
- "trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b837c6b2",
- "metadata": {},
- "source": [
- "## View model graph.\n",
- "\n",
- "PyTorch uses dynamic graphs under the hood. The graphs are constructed on the fly. This is in contrast to TensorFlow, where the graph is constructed before the training loop and remains static. In other words, the graph of the network can change with every forward pass. Therefore, we need to supply an input tensor to construct the graph. The input tensor can be a random tensor of the correct shape and type. We can also supply a real image from the dataset. The latter is more useful for debugging."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "31665d0f",
- "metadata": {},
- "source": [
- "
\n",
- "\n",
- "### Task 1.4\n",
- "Run the next cell to generate a graph representation of the model architecture. Can you recognize the UNet structure and skip connections in this graph visualization?\n",
- "
\n",
- "\n",
- "### Task 1.5\n",
- "Start training by running the following cell. Check the new logs on the tensorboard.\n",
- "
"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "26303693",
- "metadata": {},
- "outputs": [],
- "source": [
- "\n",
- "GPU_ID = 0\n",
- "n_samples = len(phase2fluor_data.train_dataset)\n",
- "steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch.\n",
- "n_epochs = 50 # Set this to 50 or the number of epochs you want to train for.\n",
- "\n",
- "trainer = VSTrainer(\n",
- " accelerator=\"gpu\",\n",
- " devices=[GPU_ID],\n",
- " max_epochs=n_epochs,\n",
- " log_every_n_steps=steps_per_epoch // 2,\n",
- " # log losses and image samples 2 times per epoch.\n",
- " logger=TensorBoardLogger(\n",
- " save_dir=log_dir,\n",
- " # lightning trainer transparently saves logs and model checkpoints in this directory.\n",
- " name=\"phase2fluor\",\n",
- " log_graph=True,\n",
- " ),\n",
- " ) \n",
- "# Launch training and check that loss and images are being logged on tensorboard.\n",
- "trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "4260177d",
- "metadata": {
- "cell_marker": "\"\"\""
- },
- "source": [
- "
\n",
- "\n",
- "## Checkpoint 1\n",
- "\n",
- "Now the training has started,\n",
- "we can come back after a while and evaluate the performance!\n",
- "
-"""
-# %%
-"""
-# Part 1: Log training data to tensorboard, start training a model.
----------
-
-Learning goals:
-
-- Load the OME-zarr dataset and examine the channels.
-- Configure and understand the data loader.
-- Log some patches to tensorboard.
-- Initialize a 2D U-Net model for virtual staining
-- Start training the model to predict nuclei and membrane from phase.
-"""
-
-# %% Imports and paths
-from pathlib import Path
-
-import matplotlib.pyplot as plt
-import numpy as np
-import pandas as pd
-import torch
-import torchview
-import torchvision
-from iohub import open_ome_zarr
-from lightning.pytorch import seed_everything
-from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
-from skimage import metrics # for metrics.
-
-# %% Imports and paths
-# pytorch lightning wrapper for Tensorboard.
-from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard
-
-# HCSDataModule makes it easy to load data during training.
-from viscy.light.data import HCSDataModule
-
-# Trainer class and UNet.
-from viscy.light.engine import VSTrainer, VSUNet
-
-seed_everything(42, workers=True)
-
-# Paths to data and log directory
-data_path = Path(
- Path("~/data/04_image_translation/HEK_nuclei_membrane_pyramid.zarr/")
-).expanduser()
-
-log_dir = Path("~/data/04_image_translation/logs/").expanduser()
-
-# Create log directory if needed, and launch tensorboard
-log_dir.mkdir(parents=True, exist_ok=True)
-
-# %% [markdown] tags=[]
-'''
-The next cell starts tensorboard within the notebook.
-
-
-If you launched jupyter lab from ssh terminal, add --host <your-server-name> to the tensorboard command below. <your-server-name> is the address of your compute node that ends in amazonaws.com.
-
-You can also launch tensorboard in an independent tab (instead of in the notebook) by changing the `%` to `!`
-
-'''
-
-# %% Imports and paths tags=[]
-%reload_ext tensorboard
-%tensorboard --logdir {log_dir}
-
-# %% [markdown]
-"""
-## Load Dataset.
-
-There should be 301 FOVs in the dataset (12 GB compressed).
-
-Each FOV consists of 3 channels of 2048x2048 images,
-saved in the
-High-Content Screening (HCS) layout
-specified by the Open Microscopy Environment Next Generation File Format
-(OME-NGFF).
-
-The layout on the disk is: row/col/field/pyramid_level/timepoint/channel/z/y/x.
-Notice that labelling of nuclei channel is not complete - some cells are not expressing the fluorescent protein.
-"""
-
-# %%
-dataset = open_ome_zarr(data_path)
-
-print(f"Number of positions: {len(list(dataset.positions()))}")
-
-# Use the field and pyramid_level below to visualize data.
-row = 0
-col = 0
-field = 23 # TODO: Change this to explore data.
-
-# This dataset contains images at 3 resolutions.
-# '0' is the highest resolution
-# '1' is down-scaled 2x2,
-# '2' is down-scaled 4x4.
-# Such datasets are called image pyramids.
-pyaramid_level = 0
-
-# `channel_names` is the metadata that is stored with data according to the OME-NGFF spec.
-n_channels = len(dataset.channel_names)
-
-image = dataset[f"{row}/{col}/{field}/{pyaramid_level}"].numpy()
-print(f"data shape: {image.shape}, FOV: {field}, pyramid level: {pyaramid_level}")
-
-figure, axes = plt.subplots(1, n_channels, figsize=(9, 3))
-
-for i in range(n_channels):
- for i in range(n_channels):
- channel_image = image[0, i, 0]
- # Adjust contrast to 0.5th and 99.5th percentile of pixel values.
- p_low, p_high = np.percentile(channel_image, (0.5, 99.5))
- channel_image = np.clip(channel_image, p_low, p_high)
- axes[i].imshow(channel_image, cmap="gray")
- axes[i].axis("off")
- axes[i].set_title(dataset.channel_names[i])
-plt.tight_layout()
-
-# %% [markdown]
-#
-#
-# ### Task 1.1
-#
-# Look at a couple different fields of view by changing the value in the cell above. See if you notice any missing or inconsistent staining.
-#
-
-# %% [markdown]
-"""
-## Explore the effects of augmentation on batch.
-
-VisCy builds on top of PyTorch Lightning. PyTorch Lightning is a thin wrapper around PyTorch that allows rapid experimentation. It provides a [DataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) to handle loading and processing of data during training. VisCy provides a child class, `HCSDataModule` to make it intuitve to access data stored in the HCS layout.
-
-The dataloader in `HCSDataModule` returns a batch of samples. A `batch` is a list of dictionaries. The length of the list is equal to the batch size. Each dictionary consists of following key-value pairs.
-- `source`: the input image, a tensor of size 1*1*Y*X
-- `target`: the target image, a tensor of size 2*1*Y*X
-- `index` : the tuple of (location of field in HCS layout, time, and z-slice) of the sample.
-"""
-
-# %% [markdown]
-#
-#
-# ### Task 1.2
-#
-# Setup the data loader and log several batches to tensorboard.
-#
-# Based on the tensorboard images, what are the two channels in the target image?
-#
-# Note: If tensorboard is not showing images, try refreshing and using the "Images" tab.
-#
-
-# %%
-# Define a function to write a batch to tensorboard log.
-
-def log_batch_tensorboard(batch, batchno, writer, card_name):
- """
- Logs a batch of images to TensorBoard.
-
- Args:
- batch (dict): A dictionary containing the batch of images to be logged.
- writer (SummaryWriter): A TensorBoard SummaryWriter object.
- card_name (str): The name of the card to be displayed in TensorBoard.
-
- Returns:
- None
- """
- batch_phase = batch["source"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor.
- batch_membrane = batch["target"][:, 1, 0, :, :].unsqueeze(
- 1
- ) # batch_size x 1 x Y x X tensor.
- batch_nuclei = batch["target"][:, 0, 0, :, :].unsqueeze(
- 1
- ) # batch_size x 1 x Y x X tensor.
-
- p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))
- batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)
-
- p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))
- batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)
-
- p1, p99 = np.percentile(batch_phase, (0.1, 99.9))
- batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)
-
- [N, C, H, W] = batch_phase.shape
- interleaved_images = torch.zeros((3 * N, C, H, W), dtype=batch_phase.dtype)
- interleaved_images[0::3, :] = batch_phase
- interleaved_images[1::3, :] = batch_nuclei
- interleaved_images[2::3, :] = batch_membrane
-
- grid = torchvision.utils.make_grid(interleaved_images, nrow=3)
-
- # add the grid to tensorboard
- writer.add_image(card_name, grid, batchno)
-
-
-# %%
-# Define a function to visualize a batch on jupyter, in case tensorboard is finicky
-
-def log_batch_jupyter(batch):
- """
- Logs a batch of images on jupyter using ipywidget.
-
- Args:
- batch (dict): A dictionary containing the batch of images to be logged.
-
- Returns:
- None
- """
- batch_phase = batch["source"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor.
- batch_size = batch_phase.shape[0]
- batch_membrane = batch["target"][:, 1, 0, :, :].unsqueeze(
- 1
- ) # batch_size x 1 x Y x X tensor.
- batch_nuclei = batch["target"][:, 0, 0, :, :].unsqueeze(
- 1
- ) # batch_size x 1 x Y x X tensor.
-
- p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))
- batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)
-
- p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))
- batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)
-
- p1, p99 = np.percentile(batch_phase, (0.1, 99.9))
- batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)
-
- plt.figure()
- fig, axes = plt.subplots(batch_size, n_channels, figsize=(10, 10))
- [N, C, H, W] = batch_phase.shape
- for sample_id in range(batch_size):
- axes[sample_id, 0].imshow(batch_phase[sample_id,0])
- axes[sample_id, 1].imshow(batch_nuclei[sample_id,0])
- axes[sample_id, 2].imshow(batch_membrane[sample_id,0])
-
- for i in range(n_channels):
- axes[sample_id, i].axis("off")
- axes[sample_id, i].set_title(dataset.channel_names[i])
- plt.tight_layout()
- plt.show()
-
-
-# %%
-
-# Initialize the data module.
-
-BATCH_SIZE = 4
-# 42 is a perfectly reasonable batch size. After all, it is the answer to the ultimate question of life, the universe and everything.
-# More seriously, batch size does not have to be a power of 2.
-# See: https://sebastianraschka.com/blog/2022/batch-size-2.html
-
-data_module = HCSDataModule(
- data_path,
- source_channel="Phase",
- target_channel=["Nuclei", "Membrane"],
- z_window_size=1,
- split_ratio=0.8,
- batch_size=BATCH_SIZE,
- num_workers=8,
- architecture="2D",
- yx_patch_size=(512, 512), # larger patch size makes it easy to see augmentations.
- augment=False, # Turn off augmentation for now.
-)
-data_module.setup("fit")
-
-print(
- f"FOVs in training set: {len(data_module.train_dataset)}, FOVs in validation set:{len(data_module.val_dataset)}"
-)
-train_dataloader = data_module.train_dataloader()
-
-# Instantiate the tensorboard SummaryWriter, logs the first batch and then iterates through all the batches and logs them to tensorboard.
-
-writer = SummaryWriter(log_dir=f"{log_dir}/view_batch")
-# Draw a batch and write to tensorboard.
-batch = next(iter(train_dataloader))
-log_batch_tensorboard(batch, 0, writer, "augmentation/none")
-writer.close()
-
-
-# %% [markdown]
-# Visualize directly on Jupyter ☄️, if your tensorboard is causing issues.
-
-# %%
-%matplotlib inline
-log_batch_jupyter(batch)
-
-# %% [markdown]
-"""
-## View augmentations using tensorboard.
-"""
-# %%
-# Here we turn on data augmentation and rerun setup
-data_module.augment = True
-data_module.setup("fit")
-
-# get the new data loader with augmentation turned on
-augmented_train_dataloader = data_module.train_dataloader()
-
-# Draw batches and write to tensorboard
-writer = SummaryWriter(log_dir=f"{log_dir}/view_batch")
-augmented_batch = next(iter(augmented_train_dataloader))
-log_batch_tensorboard(augmented_batch, 0, writer, "augmentation/some")
-writer.close()
-
-# %% [markdown]
-# Visualize directly on Jupyter ☄️
-
-# %%
-log_batch_jupyter(augmented_batch)
-
-# %% [markdown]
-#
-#
-# ### Task 1.3
-# Can you tell what augmentation were applied from looking at the augmented images in Tensorboard?
-#
-# Check your answer using the source code [here](https://github.com/mehta-lab/VisCy/blob/b89f778b34735553cf155904eef134c756708ff2/viscy/light/data.py#L529).
-#
-
-# %% [markdown]
-"""
-## Train a 2D U-Net model to predict nuclei and membrane from phase.
-
-### Construct a 2D U-Net
-See ``viscy.unet.networks.Unet2D.Unet2d`` ([source code](https://github.com/mehta-lab/VisCy/blob/7c5e4c1d68e70163cf514d22c475da8ea7dc3a88/viscy/unet/networks/Unet2D.py#L7)) for configuration details.
-"""
-# %%
-# Create a 2D UNet.
-GPU_ID = 0
-BATCH_SIZE = 10
-YX_PATCH_SIZE = (512, 512)
-
-
-# Dictionary that specifies key parameters of the model.
-phase2fluor_config = {
- "architecture": "2D",
- "num_filters": [24, 48, 96, 192, 384],
- "in_channels": 1,
- "out_channels": 2,
- "residual": True,
- "dropout": 0.1, # dropout randomly turns off weights to avoid overfitting of the model to data.
- "task": "reg", # reg = regression task.
-}
-
-phase2fluor_model = VSUNet(
- model_config=phase2fluor_config.copy(),
- batch_size=BATCH_SIZE,
- loss_function=torch.nn.functional.l1_loss,
- schedule="WarmupCosine",
- log_num_samples=5, # Number of samples from each batch to log to tensorboard.
- example_input_yx_shape=YX_PATCH_SIZE,
-)
-
-
-# %% [markdown]
-"""
-### Instantiate the data module and trainer, and test that we are set up to launch training.
-"""
-# %%
-# Setup the data module.
-phase2fluor_data = HCSDataModule(
- data_path,
- source_channel="Phase",
- target_channel=["Nuclei", "Membrane"],
- z_window_size=1,
- split_ratio=0.8,
- batch_size=BATCH_SIZE,
- num_workers=8,
- architecture="2D",
- yx_patch_size=YX_PATCH_SIZE,
- augment=True,
-)
-phase2fluor_data.setup("fit")
-# fast_dev_run runs a single batch of data through the model to check for errors.
-trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], fast_dev_run=True)
-
-# trainer class takes the model and the data module as inputs.
-trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)
-
-
-# %% [markdown]
-# ## View model graph.
-#
-# PyTorch builds computation graphs dynamically: the graph is constructed on the fly during each forward pass, in contrast to static-graph frameworks such as TensorFlow 1.x, where the graph is built once before the training loop. Because the graph can change with every forward pass, we need to supply an input tensor to construct it for visualization. The input can be a random tensor of the correct shape and dtype, or a real image from the dataset; the latter is more useful for debugging.
-
-# %% [markdown]
-#
-#
-# ### Task 1.4
-# Run the next cell to generate a graph representation of the model architecture. Can you recognize the UNet structure and skip connections in this graph visualization?
-#
-
-# %%
-# Visualize the graph of the phase2fluor model as an image.
-model_graph_phase2fluor = torchview.draw_graph(
- phase2fluor_model,
- phase2fluor_data.train_dataset[0]["source"],
- depth=2, # adjust depth to zoom in.
- device="cpu",
-)
-# Print the image of the model.
-model_graph_phase2fluor.visual_graph
-
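-# %%
-# As noted above, a random tensor of the right shape also works for graph construction.
-# A minimal sketch: torch.rand_like gives a random tensor with the same shape and dtype
-# as the real sample (assuming the sample is a floating-point tensor), which is handy
-# when you want to draw the graph without touching the data.
-random_input = torch.rand_like(phase2fluor_data.train_dataset[0]["source"])
-model_graph_random = torchview.draw_graph(
-    phase2fluor_model,
-    random_input,
-    depth=2,  # adjust depth to zoom in.
-    device="cpu",
-)
-model_graph_random.visual_graph
-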
-# %% [markdown]
-"""
-
-
-### Task 1.5
-Start training by running the following cell. Check the new logs on the tensorboard.
-
-"""
-
-
-# %%
-
-GPU_ID = 0
-n_samples = len(phase2fluor_data.train_dataset)
-steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch.
-n_epochs = 50 # Set this to 50 or the number of epochs you want to train for.
-
-trainer = VSTrainer(
- accelerator="gpu",
- devices=[GPU_ID],
- max_epochs=n_epochs,
- log_every_n_steps=steps_per_epoch // 2,
- # log losses and image samples 2 times per epoch.
- logger=TensorBoardLogger(
- save_dir=log_dir,
- # lightning trainer transparently saves logs and model checkpoints in this directory.
- name="phase2fluor",
- log_graph=True,
- ),
-)
-# Launch training and check that loss and images are being logged on tensorboard.
-trainer.fit(phase2fluor_model, datamodule=phase2fluor_data)
-
-# %% [markdown]
-"""
-
-
-## Checkpoint 1
-
-Now that the training has started,
-we can come back after a while and evaluate its performance!
-
-"""
-
-# %%
-"""
-# Part 2: Assess the previous model, then train a fluorescence to phase contrast translation model.
---------------------------------------------------
-"""
-
-# %% [markdown]
-"""
-We now look at some performance metrics for the previous model. We typically evaluate model performance on held-out test data. We will use the following metrics to assess the accuracy of the model's regression:
-- [Pearson correlation](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient).
-- [Structural similarity](https://en.wikipedia.org/wiki/Structural_similarity) (SSIM).
-
-You should also look at the validation samples on tensorboard (hint: the experimental data in the nuclei channel is imperfect). A minimal sketch of how to compute these metrics on a validation batch follows below.
-"""
-
-# %% [markdown]
-"""
-
-
-### Task 2.1 Define metrics
-
-For each of the above metrics, write a brief definition of what they are and what they mean for this image translation task.
-
-
-
-### Task 2.2 Train fluorescence to phase contrast translation model
-
-Instantiate a data module, model, and trainer for fluorescence to phase contrast translation. Copy over the code from previous cells and update the parameters. Give the variables and paths a different name/suffix (fluor2phase) to avoid overwriting objects used to train phase2fluor models.
-
-"""
-# %% tags=[]
-##########################
-######## TODO ########
-##########################
-
-fluor2phase_data = HCSDataModule(
- # Your code here (copy from above and modify as needed)
-)
-fluor2phase_data.setup("fit")
-
-# Dictionary that specifies key parameters of the model.
-fluor2phase_config = {
- # Your config here
-}
-
-fluor2phase_model = VSUNet(
- # Your code here (copy from above and modify as needed)
-)
-
-trainer = VSTrainer(
- # Your code here (copy from above and modify as needed)
-)
-trainer.fit(fluor2phase_model, datamodule=fluor2phase_data)
-
-
-# Visualize the graph of the fluor2phase model as an image.
-model_graph_fluor2phase = torchview.draw_graph(
- fluor2phase_model,
- fluor2phase_data.train_dataset[0]["source"],
- depth=2, # adjust depth to zoom in.
- device="cpu",
-)
-model_graph_fluor2phase.visual_graph
-
-# %% tags=["solution"]
-
-##########################
-######## Solution ########
-##########################
-
-# The entire training loop is contained in this cell.
-
-fluor2phase_data = HCSDataModule(
- data_path,
- source_channel="Membrane",
- target_channel="Phase",
- z_window_size=1,
- split_ratio=0.8,
- batch_size=BATCH_SIZE,
- num_workers=8,
- architecture="2D",
- yx_patch_size=YX_PATCH_SIZE,
- augment=True,
-)
-fluor2phase_data.setup("fit")
-
-# Dictionary that specifies key parameters of the model.
-fluor2phase_config = {
- "architecture": "2D",
- "in_channels": 1,
- "out_channels": 1,
- "residual": True,
-    "dropout": 0.1,  # dropout randomly zeroes out activations to reduce overfitting.
- "task": "reg", # reg = regression task.
- "num_filters": [24, 48, 96, 192, 384],
-}
-
-fluor2phase_model = VSUNet(
- model_config=fluor2phase_config.copy(),
- batch_size=BATCH_SIZE,
- loss_function=torch.nn.functional.mse_loss,
- schedule="WarmupCosine",
- log_num_samples=5,
- example_input_yx_shape=YX_PATCH_SIZE,
-)
-
-
-trainer = VSTrainer(
- accelerator="gpu",
- devices=[GPU_ID],
- max_epochs=n_epochs,
- log_every_n_steps=steps_per_epoch // 2,
- logger=TensorBoardLogger(
- save_dir=log_dir,
- # lightning trainer transparently saves logs and model checkpoints in this directory.
- name="fluor2phase",
- log_graph=True,
- ),
-)
-trainer.fit(fluor2phase_model, datamodule=fluor2phase_data)
-
-
-# Visualize the graph of the fluor2phase model as an image.
-model_graph_fluor2phase = torchview.draw_graph(
- fluor2phase_model,
- fluor2phase_data.train_dataset[0]["source"],
- depth=2, # adjust depth to zoom in.
- device="cpu",
-)
-model_graph_fluor2phase.visual_graph
-
-# %% [markdown] tags=[]
-"""
-
-
-### Task 2.3
-
-While your model is training, let's think about the following questions:
-- What is the information content of each channel in the dataset?
-- How would you use image translation models?
-- What can you try to improve the performance of each model?
-
-
-## Checkpoint 2
-When your model finishes training, please summarize the hyperparameters and performance of your models in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z)
-
-
-"""
-
-# %% tags=[]
-"""
-# Part 3: Tune the models.
---------------------------------------------------
-
-Learning goals: Understand how data, model capacity, and training parameters control the performance of the model. Your goal is to try to underfit or overfit the model.
-"""
-
-
-# %% [markdown] tags=[]
-"""
-
-
-### Task 3.1
-
-- Choose a model you want to train (phase2fluor or fluor2phase).
-- Set up a configuration that you think will improve the performance of the model.
-- Consider modifying the learning rate and observe how it changes performance.
-- Use the training loop illustrated in the previous cells as a template to prototype your own training loop.
-- Add code to evaluate the model using Pearson correlation and SSIM.
-
-As your model is training, please document hyperparameters, snapshots of predictions on the validation set, and loss curves for your models in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z)
-
-
-"""
-# %% tags=[]
-##########################
-######## TODO ########
-##########################
-
-tune_data = HCSDataModule(
- # Your code here (copy from above and modify as needed)
-)
-tune_data.setup("fit")
-
-# Dictionary that specifies key parameters of the model.
-tune_config = {
- # Your config here
-}
-
-tune_model = VSUNet(
- # Your code here (copy from above and modify as needed)
-)
-
-trainer = VSTrainer(
- # Your code here (copy from above and modify as needed)
-)
-trainer.fit(tune_model, datamodule=tune_data)
-
-
-# Visualize the graph of the tuned model as an image.
-model_graph_tune = torchview.draw_graph(
- tune_model,
- tune_data.train_dataset[0]["source"],
- depth=2, # adjust depth to zoom in.
- device="cpu",
-)
-model_graph_tune.visual_graph
-
-
-# %% tags=["solution"]
-
-##########################
-######## Solution ########
-##########################
-
-phase2fluor_wider_config = {
- "architecture": "2D",
- # double the number of filters at each stage
- "num_filters": [48, 96, 192, 384, 768],
- "in_channels": 1,
- "out_channels": 2,
- "residual": True,
- "dropout": 0.1,
- "task": "reg",
-}
-
-phase2fluor_wider_model = VSUNet(
- model_config=phase2fluor_wider_config.copy(),
- batch_size=BATCH_SIZE,
- loss_function=torch.nn.functional.l1_loss,
- schedule="WarmupCosine",
- log_num_samples=5,
- example_input_yx_shape=YX_PATCH_SIZE,
-)
-
-
-trainer = VSTrainer(
- accelerator="gpu",
- devices=[GPU_ID],
- max_epochs=n_epochs,
- log_every_n_steps=steps_per_epoch,
- logger=TensorBoardLogger(
- save_dir=log_dir,
- name="phase2fluor",
- version="wider",
- log_graph=True,
- ),
- fast_dev_run=True,
-) # Set fast_dev_run to False to train the model.
-trainer.fit(phase2fluor_wider_model, datamodule=phase2fluor_data)
-
-# %% tags=["solution"]
-
-##########################
-######## Solution ########
-##########################
-
-phase2fluor_slow_model = VSUNet(
- model_config=phase2fluor_config.copy(),
- batch_size=BATCH_SIZE,
- loss_function=torch.nn.functional.l1_loss,
-    # lower the learning rate by a factor of 5
- lr=2e-4,
- schedule="WarmupCosine",
- log_num_samples=5,
- example_input_yx_shape=YX_PATCH_SIZE,
-)
-
-trainer = VSTrainer(
- accelerator="gpu",
- devices=[GPU_ID],
- max_epochs=n_epochs,
- log_every_n_steps=steps_per_epoch,
- logger=TensorBoardLogger(
- save_dir=log_dir,
- name="phase2fluor",
- version="low_lr",
- log_graph=True,
- ),
- fast_dev_run=True,
-)
-trainer.fit(phase2fluor_slow_model, datamodule=phase2fluor_data)
-
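-# %% tags=["solution"]
-# One more variation you could try (a sketch, not required): shrink the network to
-# encourage underfitting. The filter counts below are illustrative; everything else
-# mirrors the training cells above.
-phase2fluor_slim_config = {
-    "architecture": "2D",
-    # halve the number of filters at each stage
-    "num_filters": [12, 24, 48, 96, 192],
-    "in_channels": 1,
-    "out_channels": 2,
-    "residual": True,
-    "dropout": 0.1,
-    "task": "reg",
-}
-
-phase2fluor_slim_model = VSUNet(
-    model_config=phase2fluor_slim_config.copy(),
-    batch_size=BATCH_SIZE,
-    loss_function=torch.nn.functional.l1_loss,
-    schedule="WarmupCosine",
-    log_num_samples=5,
-    example_input_yx_shape=YX_PATCH_SIZE,
-)
-
-trainer = VSTrainer(
-    accelerator="gpu",
-    devices=[GPU_ID],
-    max_epochs=n_epochs,
-    log_every_n_steps=steps_per_epoch,
-    logger=TensorBoardLogger(
-        save_dir=log_dir,
-        name="phase2fluor",
-        version="slim",
-        log_graph=True,
-    ),
-    fast_dev_run=True,  # Set fast_dev_run to False to train the model.
-)
-trainer.fit(phase2fluor_slim_model, datamodule=phase2fluor_data)
-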
-
-# %% [markdown] tags=[]
-"""
-
-
-## Checkpoint 3
-
-Congratulations! You have trained several image translation models now!
-Please document hyperparameters, snapshots of predictions on the validation set, and loss curves for your models, and add the final performance in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z). We'll discuss our combined results as a group.
-