Skip to content

Commit

Permalink
- fixing n_channels
Browse files Browse the repository at this point in the history
- adding path to backup trained model
- increased learning rate
  • Loading branch information
edyoshikun committed Aug 25, 2024
1 parent d9c978c commit 9a9e917
Showing 1 changed file with 82 additions and 45 deletions.
127 changes: 82 additions & 45 deletions solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,55 +4,65 @@
# Written by Eduardo Hirata-Miyasaki, Ziwen Liu, and Shalin Mehta, CZ Biohub San Francisco

# ## Overview

# In this exercise, we will predict fluorescence images of
# nuclei and plasma membrane markers from quantitative phase images of cells,
# i.e., we will _virtually stain_ the nuclei and plasma membrane
# visible in the phase image.
# This is an example of an image translation task.
# We will apply spatial and intensity augmentations to train robust models
# and evaluate their performance using a regression approach.

#
# In this exercise, we will _virtually stain_ the nuclei and plasma membrane from the quantitative phase image (QPI), i.e., translate QPI images into fluoresence images of nuclei and plasma membranes.
# QPI encodes multiple cellular structures and virtual staining decomposes these structures. After the model is trained, one only needs to acquire label-free QPI data.
# This strategy solves the problem as "multi-spectral imaging", but is more compatible with live cell imaging and high-throughput screening.
# Virtual staining is often a step towards multiple downstream analyses: segmentation, tracking, and cell state phenotyping.
#
# In this exercise, you will:
# - Train a model to predict the fluorescence images of nuclei and plasma membranes from QPI images
# - Make it robust to variations in imaging conditions using data augmentions
# - Segment the cells
# - Use regression and segmentation metrics to evalute the models
# - Visualize the image transform learned by the model
# - Understand the failure modes of the trained model
#
# [![HEK293T](https://raw.githubusercontent.com/mehta-lab/VisCy/main/docs/figures/svideo_1.png)](https://github.com/mehta-lab/VisCy/assets/67518483/d53a81eb-eb37-44f3-b522-8bd7bddc7755)
# (Click on image to play video)

#
# %% [markdown] tags=[]
# ### Goals

# #### Part 1: Learn to use iohub (I/O library), VisCy dataloaders, and TensorBoard.

# - Use a OME-Zarr dataset of 34 FOVs of adenocarcinomic human alveolar basal epithelial cells (A549),
# each FOV has 3 channels (phase, nuclei, and cell membrane).
# The nuclei were stained with DAPI and the cell membrane with Cellmask.
# #### Part 1: Train a virtual staining model
#
# - Explore OME-Zarr using [iohub](https://czbiohub-sf.github.io/iohub/main/index.html)
# and the high-content-screen (HCS) format.
# - Use [MONAI](https://monai.io/) to implement data augmentations.

# #### Part 2: Train and evaluate the model to translate phase into fluorescence.
# - Train a 2D UNeXt2 model to predict nuclei and membrane from phase images.
# - Compare the performance of the trained model and a pre-trained model.
# - Use our `viscy.data.HCSDataloader()` dataloader and explore the 3 channel (phase, fluoresecence nuclei and cell membrane)
# A549 cell dataset.
# - Implement data augmentations [MONAI](https://monai.io/) to train a robust model to imaging parameters and conditions.
# - Use tensorboard to log the augmentations, training and validation losses and batches
# - Start the training of the UNeXt2 model to predict nuclei and membrane from phase images.
#
# #### Part 2:Evaluate the model to translate phase into fluorescence.
# - Compare the performance of your trained model with the _VSCyto2D_ pre-trained model.
# - Evaluate the model using pixel-level and instance-level metrics.


# Checkout [VisCy](https://github.com/mehta-lab/VisCy/tree/main/examples/demos),
#
# #### Part 3: Visualize the image transforms learned by the model and explore the model's regime of validity
# - Visualize the first 3 principal componets mapped to a color space in each encoder and decoder block.
# - Explore the model's regime of validity by applying blurring and scaling transforms to the input phase image.
#
# #### For more information:
# Checkout [VisCy](https://github.com/mehta-lab/VisCy),
# our deep learning pipeline for training and deploying computer vision models
# for image-based phenotyping including the robust virtual staining of landmark organelles.
#
# VisCy exploits recent advances in data and metadata formats
# ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks,
# [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/).

# ### References

# - [Liu, Z. and Hirata-Miyasaki, E. et al. (2024) Robust Virtual Staining of Cellular Landmarks](https://www.biorxiv.org/content/10.1101/2024.05.31.596901v2.full.pdf)
# - [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning. eLife](https://elifesciences.org/articles/55502)

# %% [markdown] tags=[]
# <div class="alert alert-info">
# The exercise is organized in 2 parts
# <div class="alert alert-success">
# The exercise is organized in 3 parts:

# <ul>
# <li><b>Part 1</b> - Learn to use iohub (I/O library), VisCy dataloaders, and tensorboard.</li>
# <li><b>Part 2</b> - Train and evaluate the model to translate phase into fluorescence.</li>
# <li><b>Part 1</b> - Train a virtual staining model using iohub (I/O library), VisCy dataloaders, and tensorboard</li>
# <li><b>Part 2</b> - Evaluate the model to translate phase into fluorescence.</li>
# <li><b>Part 3</b> - Visualize the image transforms learned by the model and explore the model's regime of validity.</li>
# </ul>

# </div>
Expand All @@ -62,7 +72,7 @@
# Set your python kernel to <span style="color:black;">06_image_translation</span>
# </div>
# %% [markdown]
# ## Part 1: Log training data to tensorboard, start training a model.
# # Part 1: Log training data to tensorboard, start training a model.
# ---------
# Learning goals:

Expand Down Expand Up @@ -188,12 +198,14 @@ def launch_tensorboard(log_dir):
# ## Load OME-Zarr Dataset

# There should be 34 FOVs in the dataset.

#
# Each FOV consists of 3 channels of 2048x2048 images,
# saved in the [High-Content Screening (HCS) layout](https://ngff.openmicroscopy.org/latest/#hcs-layout)
# specified by the Open Microscopy Environment Next Generation File Format
# (OME-NGFF).

#
# The 3 channels correspond to the QPI, nuclei, and cell membrane. The nuclei were stained with DAPI and the cell membrane with Cellmask.
#
# - The layout on the disk is: `row/col/field/pyramid_level/timepoint/channel/z/y/x.`
# - These datasets only have 1 level in the pyramid (highest resolution) which is '0'.

Expand Down Expand Up @@ -344,6 +356,7 @@ def log_batch_jupyter(batch):
p1, p99 = np.percentile(batch_phase, (0.1, 99.9))
batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)

n_channels = batch["target"].shape[1] + batch["source"].shape[1]
plt.figure()
fig, axes = plt.subplots(
batch_size, n_channels, figsize=(n_channels * 2, batch_size * 2)
Expand Down Expand Up @@ -525,10 +538,16 @@ def log_batch_jupyter(batch):

normalizations = [
NormalizeSampled(
keys=source_channel + target_channel,
keys=source_channel,
level="fov_statistics",
subtrahend="mean",
divisor="std",
),
NormalizeSampled(
keys=target_channel,
level="fov_statistics",
subtrahend="median",
divisor="iqr",
)
]

Expand Down Expand Up @@ -580,10 +599,16 @@ def log_batch_jupyter(batch):

normalizations = [
NormalizeSampled(
keys=source_channel + target_channel,
keys=source_channel,
level="fov_statistics",
subtrahend="mean",
divisor="std",
),
NormalizeSampled(
keys=target_channel,
level="fov_statistics",
subtrahend="median",
divisor="iqr",
)
]

Expand Down Expand Up @@ -636,7 +661,7 @@ def log_batch_jupyter(batch):
# Create a 2D UNet.
GPU_ID = 0

BATCH_SIZE = 12
BATCH_SIZE = 16
YX_PATCH_SIZE = (256, 256)

# #######################
Expand All @@ -662,7 +687,7 @@ def log_batch_jupyter(batch):
model_config=phase2fluor_config.copy(),
loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5),
schedule="WarmupCosine",
lr=2e-4,
lr=6e-4,
log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard.
freeze_encoder=False,
)
Expand Down Expand Up @@ -690,7 +715,7 @@ def log_batch_jupyter(batch):
)
phase2fluor_2D_data.setup("fit")
# fast_dev_run runs a single batch of data through the model to check for errors.
trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], fast_dev_run=True)
trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], precision='16-mixed' ,fast_dev_run=True)

# trainer class takes the model and the data module as inputs.
trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)
Expand All @@ -701,7 +726,7 @@ def log_batch_jupyter(batch):
# Here we are creating a 2D UNet.
GPU_ID = 0

BATCH_SIZE = 12
BATCH_SIZE = 16
YX_PATCH_SIZE = (256, 256)

# Dictionary that specifies key parameters of the model.
Expand All @@ -724,7 +749,7 @@ def log_batch_jupyter(batch):
model_config=phase2fluor_config.copy(),
loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5),
schedule="WarmupCosine",
lr=2e-4,
lr=6e-4,
log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard.
freeze_encoder=False,
)
Expand All @@ -751,7 +776,7 @@ def log_batch_jupyter(batch):
)
phase2fluor_2D_data.setup("fit")
# fast_dev_run runs a single batch of data through the model to check for errors.
trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], fast_dev_run=True)
trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID],precision='16-mixed', fast_dev_run=True)

# trainer class takes the model and the data module as inputs.
trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)
Expand Down Expand Up @@ -811,12 +836,13 @@ def log_batch_jupyter(batch):

n_samples = len(phase2fluor_2D_data.train_dataset)
steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch.
n_epochs = 25 # Set this to 25-30 or the number of epochs you want to train for.
n_epochs = 80 # Set this to 80-100 or the number of epochs you want to train for.

trainer = VSTrainer(
accelerator="gpu",
devices=[GPU_ID],
max_epochs=n_epochs,
precision='16-mixed',
log_every_n_steps=steps_per_epoch // 2,
# log losses and image samples 2 times per epoch.
logger=TensorBoardLogger(
Expand Down Expand Up @@ -846,7 +872,7 @@ def log_batch_jupyter(batch):

# </div>
# %% [markdown] tags=[]
# ## Part 2: Assess your trained model
# # Part 2: Assess your trained model

# Now we will look at some metrics of performance of previous model.
# We typically evaluate the model performance on a held out test data.
Expand Down Expand Up @@ -897,7 +923,7 @@ def log_batch_jupyter(batch):
#
#```python
#phase2fluor_model_ckpt = natsorted(glob(
# str(top_dir/"06_image_translation/backup/phase2fluor/version_3/checkpoints/*.ckpt")
# str(top_dir/"06_image_translation/backup/phase2fluor/version_0/checkpoints/*.ckpt")
#))[-1]
#````
# </div>
Expand Down Expand Up @@ -1084,10 +1110,9 @@ def process_image(image):
# NOTE: if their model didn't go past epoch 5, lost their checkpoint, or didnt train anything.
# Uncomment the next lines
#phase2fluor_model_ckpt = natsorted(glob(
# str(top_dir/"06_image_translation/backup/phase2fluor/version_3/checkpoints/*.ckpt")
# str(top_dir/"06_image_translation/backup/phase2fluor/version_0/checkpoints/*.ckpt")
#))[-1]


phase2fluor_config = dict(
in_channels=1,
out_channels=2,
Expand Down Expand Up @@ -1488,6 +1513,12 @@ def min_max_scale(image:ArrayLike)->ArrayLike:
#
# </div>

#%% [markdown] tags=[]
# # Part 3: Visualizing the encoder and decoder features & exploring the model's range of validity
#
# - In this section, we will visualize the encoder and decoder features of the model you trained.
# - We will also explore the model's range of validity by looking at the feature maps of the encoder and decoder.
#
# %% [markdown] tags=[]
# <div class="alert alert-info">
# <h3> Task 3.1: Let's look at what the model is learning </h3>
Expand Down Expand Up @@ -1696,10 +1727,16 @@ def clip_highlight(image: np.ndarray) -> np.ndarray:

normalizations = [
NormalizeSampled(
keys=source_channel + target_channel,
keys=source_channel,
level="fov_statistics",
subtrahend="mean",
divisor="std",
),
NormalizeSampled(
keys=target_channel,
level="fov_statistics",
subtrahend="median",
divisor="iqr",
)
]

Expand Down

0 comments on commit 9a9e917

Please sign in to comment.